Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model, which are not especially relevant from a scientific machine learning perspective.

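We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$, with $M = m_1 + m_2$ (e.g. Theoretical Mechanics of Particles and Continua, §4.3).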

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) separation `path` into the positions of two bodies with
masses `m₁` and `m₂` about their centre of mass.

Returns the tuple `(r₁, r₂)` with `r₁ = (m₂ / M) * path` and `r₂ = (-m₁ / M) * path`,
where `M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$ and $y = r\sin\phi$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution to a Cartesian orbit, returned as the 2×N matrix `[x'; y']`.

Row 1 of `soln` is `χ` and row 2 is `ϕ`. If `soln` has 4 rows, rows 3 and 4 give
per-sample values of `p` and `e`. If it has 2 rows, `model_params = (p, M, e)` supplies
constant values (`M` is not used here). The radius is `r = p / (1 + e cos χ)`, and
`x = r cos ϕ`, `y = r sin ϕ`.

Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if `model_params`
does not have length 3 in the 2-row case.
"""
function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    # Views avoid copying rows out of the solution matrix.
    χ = view(soln, 1, :)
    ϕ = view(soln, 2, :)

    if nrows == 2
        @assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = view(soln, 3, :)
        e = view(soln, 4, :)
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function computes the first time derivative using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled series `v` (spacing `dt`).

Interior points use the centred difference `(v[i+1] - v[i-1]) / 2`. The two endpoints
use second-order one-sided stencils, so the result has the same length as `v`.
Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    # Forward stencil (-3/2, 2, -1/2) at the first sample.
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Centred stencil for every interior sample.
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    # Backward stencil (3/2, -2, 1/2) at the last sample.
    tail = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    stencil_values = vcat(head, interior, tail)
    return stencil_values / dt
end

This function computes the second time derivative using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled series `v` (spacing `dt`).

Interior points use the three-point stencil `v[i-1] - 2v[i] + v[i+1]`. Both endpoints
use second-order one-sided four-point stencils `(2, -5, 4, -1)`, so the result has the
same length as `v`. Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    # One-sided stencil looking forward from the first sample.
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # Centred three-point stencil on the interior samples.
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    # Mirror-image stencil looking backward from the last sample.
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    stencil_values = vcat(head, interior, tail)
    return stencil_values / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor of a planar orbit, scaled by `mass`.

`orbit` holds `x` in row 1 and `y` in row 2. `component = (1, 1)` gives
`x² - (x² + y²)/3`, and `(2, 2)` gives `y² - (x² + y²)/3`. Any other index pair
(e.g. `(1, 2)`) gives the off-diagonal `x y` term.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i = component[1]
    j = component[2]

    # Planar orbit: z ≡ 0, so only x² and y² enter the trace.
    trace = x .^ 2 .+ y .^ 2

    moment = if i == 1 && j == 1
        x .^ 2 .- trace ./ 3
    elseif i == 2 && j == 2
        y .^ 2 .- trace ./ 3
    else
        x .* y
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the requested
trace-free moment-tensor component of `orbit`, weighted by `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole components `(h11, h12, h22)` of a single body on `orbit`,
each computed by `h_22_quadrupole_components` with sample spacing `dt` and weight `mass`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the `(h₊, hₓ)` strain of the (2,2) mode for a single unit-mass body on `orbit`.

The polarisations are built from the quadrupole components as `h₊ ∝ h11 - h22` and
`hₓ ∝ -2 h12`, and both are scaled by `√(π/5)`. Constants are converted to `T`
(the type of `dt`) to avoid needless promotion.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation √(π/5), as in the original Keith et al. (2021) code.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the binary's quadrupole components `(h11, h12, h22)`. Each one is the sum of the
corresponding components of body 1 (`orbit1`, `mass1`) and body 2 (`orbit2`, `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # Element-wise over the (h11, h12, h22) tuples.
    return map(+, body1, body2)
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the `(h₊, hₓ)` strain of the (2,2) mode for a binary. BH1 moves on `orbit1`
with mass `mass1`, and BH2 moves on `orbit2` with mass `mass2`.

The masses must sum to one (checked to within `1e-12`). The polarisations are
`h₊ ∝ h11 - h22` and `hₓ ∝ -2 h12`, both scaled by `√(π/5)`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation √(π/5), as in the original Keith et al. (2021) code.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Return the `(h₊, hₓ)` strain of the (2,2) mode computed from the ODE solution `soln`,
sampled with spacing `dt`.

`soln` is first converted to a Cartesian orbit via `soln2orbit(soln, model_params)`.
If `mass_ratio == 0`, the test-particle (one-body) formula is used. Otherwise the relative
orbit is split into two bodies with `m₂ = 1 / (1 + q)` and `m₁ = q m₂`
(`q = mass_ratio`), which sum to one.

Throws an `AssertionError` unless `0 ≤ mass_ratio ≤ 1`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

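`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using

$u_1 = \chi, \quad u_2 = \phi,$

$$\dot\chi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot\phi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.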

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = [χ, ϕ]`.

Returns `[χ̇, ϕ̇]`. Both rates share the factor
`(p - 2 - 2e cos χ)(1 + e cos χ)² / √((p - 2)² - 4e²)`. `χ̇` is additionally multiplied by
`√(p - 6 - 2e cos χ) / (M p²)`, and `ϕ̇` is divided by `M p^{3/2}`. Neither `ϕ` nor `t`
enters the rates.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    ecosχ = e * cos(χ)

    shared = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    radical = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = shared * sqrt(p - 6 - 2 * ecosχ) / (M * p^2 * radical)
    ϕ̇ = shared / (M * p^(3 / 2) * radical)
    return [χ̇, ϕ̇]
end

# Global problem configuration shared by the simulation, plotting and training cells below.
mass_ratio = 0.0         # test particle: compute_waveform takes the one-body branch when this is 0
u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
datasize = 250           # number of saved samples of the trajectory
tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
# NOTE(review): tspan is Float32 while u0/params are Float64, so tsteps and dt_data appear
# to be Float32 as well — confirm this mixed precision is intended.
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # sample spacing, used by the finite-difference derivatives
dt = 100.0               # fixed RK4 step size (solves below use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e — const because ODE_model reads it as a global

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the reference relativistic model with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# compute_waveform returns (h₊, hₓ); only the first polarisation is kept as training data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line plus markers, sharing a single legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the current value of $\chi = u_1$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
# Network mapping the scalar input to 2 outputs:
# an elementwise `cos` on the raw input, two 32-unit `cos` layers, then a linear 2-output layer.
# Weights come from a truncated normal with std 1e-4 and biases are zero, which keeps the
# initial network output small.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): the RNG is not seeded, so the initial parameters differ between runs;
# consider passing a seeded RNG for reproducibility.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.0001513677; 0.00018915694; 7.710529f-5; -8.100904f-5; -3.4554858f-5; -6.137269f-5; 9.628865f-6; 5.067023f-5; 0.00014194798; -0.00012489685; 0.000119017066; 0.00010201956; 6.250788f-5; -3.4920948f-5; 2.6304271f-5; 9.691579f-6; -5.6737285f-6; 0.00015087373; -2.170499f-5; 3.5407207f-5; 8.496613f-5; -7.817986f-5; 2.938672f-5; -2.114699f-5; -0.00015802775; -5.6834804f-5; 8.889765f-5; 5.7103025f-5; -8.753083f-6; -0.00012830622; -9.810521f-5; -0.0001792211;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.00014670115 -3.1562136f-5 4.6755034f-5 -0.00013737669 -3.8463837f-5 0.00011173407 -1.2783771f-5 0.0002548873 -0.00012667257 7.389935f-6 6.1385545f-5 5.2868632f-5 7.00988f-5 5.3468575f-5 -4.5367422f-5 0.000108431654 0.0001147943 -8.389296f-5 -2.025833f-5 -8.015727f-5 7.367238f-6 -3.560211f-5 0.00011899552 9.045202f-5 0.00011893503 -7.6827986f-5 9.736617f-5 0.00011965056 0.00017070293 5.267619f-5 -2.2739096f-5 3.8519982f-5; 8.878708f-5 -4.2997763f-5 4.2911088f-5 1.44008945f-5 -0.00014910761 -4.5071763f-5 0.0001082043 6.593332f-5 7.273428f-5 -6.56971f-5 5.2434834f-6 -0.00015571933 -4.7275447f-5 7.2396188f-6 7.930404f-5 4.9434148f-6 2.083318f-5 8.796782f-5 1.9348343f-5 -4.1246905f-5 -3.060539f-5 -5.7305304f-5 0.000154318 -6.839073f-5 -0.0001242132 1.1914611f-5 0.00026394863 7.285609f-5 9.029352f-5 -1.4622515f-5 -0.00012704151 6.01372f-5; 2.9413426f-5 -1.7574699f-5 -0.00016593013 0.00012469075 2.3947397f-5 3.225726f-5 0.00014268055 0.00019643952 0.0001223142 0.00013731256 -2.8599517f-5 3.686709f-5 1.6281965f-5 -3.627554f-5 -5.4595308f-5 8.061459f-5 0.00011920558 -0.00014793004 -2.1221129f-5 2.3573451f-5 -4.458855f-5 -6.2603576f-5 -0.00017415034 0.00013470116 1.0735284f-5 2.2797087f-5 -9.46042f-5 4.7228063f-5 7.880692f-5 0.00019094648 -0.00010571477 -7.316062f-7; 
0.00024384915 -9.719367f-5 0.00014937214 0.000107080006 -5.3318123f-5 0.00018166717 -3.7425958f-5 -4.080032f-5 -8.678032f-5 2.3219098f-5 -8.8150315f-5 6.729476f-5 -0.0001250909 0.00012272509 0.0001516958 1.4442077f-5 -1.6191298f-5 0.00013577512 -9.104566f-5 -0.0001056843 0.00011836971 -4.4466866f-5 -7.6382115f-5 -2.5522595f-5 -9.195963f-5 0.00016770982 9.557052f-5 -9.2411574f-5 -1.2784901f-5 -8.762004f-5 2.0152429f-5 8.2867606f-5; 0.00013516082 -8.261657f-6 8.427978f-5 -0.00015172818 -8.4008454f-5 7.2670773f-6 3.5761193f-5 -6.1043465f-7 0.00015104232 -0.00020706923 0.00014318238 -4.611589f-5 -0.000117455085 -2.6283282f-5 9.380495f-5 6.1170584f-5 1.7557692f-5 0.00015907727 6.684908f-5 0.00014917157 -3.14919f-5 0.00010615063 -0.00012780588 -0.00010665438 -0.0001381224 -8.407012f-5 1.0637149f-5 -0.000113843256 -5.6318117f-5 -0.00011035558 3.5857585f-5 0.00018801338; -0.00013861213 -0.0001326961 -5.0474366f-5 4.1309522f-5 -3.8840853f-6 0.00011694175 1.8540015f-5 1.6460192f-5 -0.00023897046 -8.603838f-5 -2.5170784f-5 -0.00011072696 -2.9836594f-5 0.00012050835 -4.601713f-5 7.35464f-5 2.8435486f-5 -6.705148f-6 2.9340814f-5 -3.933143f-6 0.00014103077 -1.0595498f-5 -5.238223f-5 1.9599179f-7 -5.702303f-5 -1.1091327f-5 9.65953f-6 -7.999925f-5 7.658932f-5 1.8510182f-5 3.52845f-5 0.00015348113; -6.225058f-5 6.970909f-5 6.868724f-5 5.6735873f-5 -9.591211f-5 -5.123229f-6 7.546469f-5 -0.00016442413 -5.9282123f-5 4.0096733f-5 0.0001375717 9.3773204f-5 0.00012281605 -4.4607248f-5 -1.6032658f-5 -8.922272f-5 0.00013000236 -5.486839f-5 -4.1074047f-5 0.00017482259 9.655628f-6 -8.315119f-5 -1.2675111f-5 3.765497f-5 7.6707685f-5 9.782407f-5 -9.68511f-5 -1.9089828f-5 5.2315147f-5 0.00013630521 -6.779798f-5 7.1473455f-6; 8.515057f-5 -2.0906355f-6 -2.1783492f-8 -7.451002f-6 2.9838528f-5 -1.3920647f-5 3.4929395f-5 -4.3710217f-5 5.8975627f-5 6.8646565f-5 -8.776075f-6 -4.6035337f-5 7.2849136f-5 6.914177f-5 0.0002326442 -2.7019567f-5 -3.615519f-5 -0.00010111279 -8.1267244f-5 -0.00015440152 
1.65889f-5 0.0001901506 -0.00011920705 0.000106005755 -8.643109f-5 -8.967598f-5 -0.00011112181 -3.548173f-5 -8.2298866f-5 0.00022909306 -0.00020745567 5.0288723f-5; -1.6719247f-5 1.4524906f-5 0.00016850722 5.7224788f-5 -0.00011187244 -1.1121108f-5 1.12784155f-5 -4.121019f-5 4.913626f-5 9.07716f-5 9.855194f-5 5.414876f-5 0.00016846138 2.1036616f-5 5.484972f-5 -0.00023485084 -5.99279f-6 0.000114500486 3.6053494f-5 -0.0002557504 5.674277f-5 1.827658f-5 3.2130574f-5 5.9672933f-5 0.00025442493 9.878095f-5 -2.636642f-5 -0.00018813102 -7.3138755f-5 -0.00019451683 -1.2139295f-5 -5.01993f-5; 1.4552284f-5 -4.714828f-5 -4.049922f-5 0.00012129151 3.4230186f-5 5.5112556f-5 0.00011422407 3.7858856f-5 -0.00014463687 -6.351467f-5 7.517614f-5 0.00012221924 1.5788308f-5 3.828693f-5 -0.00020482566 0.00020159129 5.0117483f-6 3.0383628f-5 -9.445336f-5 9.1518894f-5 0.00013424462 -9.110165f-5 1.4682696f-5 -0.00010873174 -2.2535534f-5 -8.357446f-5 0.000110628454 -0.00013073899 -2.549662f-6 -0.00011776637 -2.850938f-5 3.95913f-5; 0.00022516152 0.0001200711 0.00012264591 0.000106753636 2.8523911f-5 -5.9201142f-5 -0.00013068611 4.9654245f-5 6.150765f-6 5.784195f-5 -0.00011585103 -2.9194116f-5 0.00014648415 -0.0001275291 0.00024730028 0.0001265703 0.00013339963 -0.00021233055 8.691315f-5 7.476955f-5 1.571104f-5 7.825008f-5 0.00013719797 8.257209f-5 0.00011741067 -0.00014725038 9.464119f-5 7.499745f-5 -0.00012541421 -6.6527704f-5 -2.7303304f-5 0.00017512857; -0.00010110084 0.00019497496 2.0357817f-5 -9.0251335f-5 -0.0001936088 6.1880295f-5 9.369697f-5 0.00010975892 -0.00014757646 -6.367403f-5 -4.958376f-5 1.2241632f-5 8.1156f-6 -5.8143785f-5 7.550247f-5 -5.7363883f-5 -0.00010391088 6.746836f-5 -3.6211724f-5 0.00017369988 7.5660246f-6 -7.089533f-5 -0.00010272447 0.00010812496 -3.9323484f-5 0.00018973385 -0.00015507471 -7.401561f-5 -2.6138317f-5 -0.00012526764 -1.9822573f-5 0.00010637567; -3.3821507f-5 -6.0370494f-5 -6.349219f-5 3.912108f-5 -2.8333825f-5 4.111202f-5 -2.3923425f-5 4.8510898f-5 
8.808348f-5 -0.000114943396 5.7825666f-5 -6.402022f-5 -9.5298084f-5 -0.00012547198 1.9831317f-5 0.00016823737 6.122679f-5 -7.4128555f-5 0.00010079951 -1.7591754f-5 -5.768521f-5 6.24763f-7 -6.742748f-5 1.2204229f-5 -0.00016545855 -2.156792f-5 3.7653062f-5 5.2166433f-5 0.0001532002 -2.1181615f-5 -3.9049773f-5 2.4227918f-5; 0.00013145996 4.111283f-5 -9.779278f-5 -0.0001075436 -3.4918325f-5 -0.00013122096 -2.8631855f-5 -4.189165f-5 -3.3932014f-5 -0.00012430696 -5.213516f-6 9.699098f-5 0.00031222185 0.0002528282 -8.366024f-5 0.00015666407 1.5132436f-5 -6.465772f-5 0.00023015354 3.7683158f-6 2.6428565f-5 8.342638f-5 1.8304732f-5 -0.00012355618 -8.957665f-5 0.00015540532 -0.00027022092 7.608295f-5 -7.3925266f-6 9.711566f-5 8.776839f-5 9.536507f-5; -4.7980634f-6 0.00017457352 -2.7113938f-5 1.3730527f-5 -4.337065f-5 -7.308925f-5 5.8918442f-5 4.0534822f-5 -4.4319295f-5 6.7801455f-5 6.821064f-7 -5.2756437f-5 -3.331549f-5 -1.2188453f-5 -8.44365f-5 5.4042815f-5 -1.1305361f-5 0.00016811279 0.00016254645 7.111442f-5 1.8924615f-5 -2.9761808f-5 -1.5722038f-5 8.715754f-5 -2.8223321f-5 -0.00013831556 -2.934744f-5 9.7530676f-5 6.148341f-5 -0.00011644177 -0.00022698866 -2.6884786f-6; 1.4361872f-5 8.368537f-5 -6.2217616f-5 0.00015794067 6.762896f-5 0.00016877249 3.975084f-5 -0.00015008559 1.0775017f-5 -4.41127f-5 1.721865f-5 7.529273f-5 -9.235939f-5 -0.00020597893 -3.3173983f-5 9.2320715f-6 0.00010322759 8.4517145f-5 -0.00013205026 -0.00017527983 -2.7656795f-5 0.00010319713 -5.431596f-5 9.588362f-5 8.863711f-5 2.1111207f-5 -7.9605474f-5 -0.000116418836 4.7570367f-5 0.00010940163 1.5451134f-5 -0.00015858501; -0.000100975696 -4.706443f-5 4.349108f-5 -1.5211025f-5 0.00020972341 6.0945727f-6 -0.000115840485 3.9133727f-5 1.8962368f-5 0.0001337197 3.9860184f-5 -0.00010023575 -6.867299f-6 4.6392226f-5 0.00025507205 -7.793379f-5 -5.905416f-5 0.000101436875 0.00013414246 5.0954797f-5 3.30386f-5 -7.667922f-5 3.0797132f-6 0.00011646412 -1.6553047f-5 -6.104633f-5 8.030591f-5 -2.5543264f-5 
-0.00017780981 -4.2430667f-5 -2.1381236f-5 -0.00011474265; -7.529719f-7 0.00012716521 4.5013276f-5 -1.5217192f-5 -5.9103724f-5 0.00011544879 -9.4039904f-5 6.0515027f-5 -2.295613f-5 -1.1065727f-5 -0.00014319099 0.00011402137 0.00013570508 1.334092f-5 5.81997f-6 -2.1908418f-5 -6.663524f-5 0.00014071024 -0.00017402302 8.294258f-7 5.7088444f-5 -0.00011947913 -6.613895f-5 -0.00010017691 2.0177922f-5 -2.331573f-5 -8.0581114f-5 2.2147615f-6 -7.942783f-5 -4.3200827f-5 -5.700625f-5 -1.579911f-5; -7.2428986f-5 1.2364465f-6 3.294771f-5 -7.5986085f-5 2.0823027f-5 -4.7144844f-5 2.2866387f-5 -7.0600705f-5 -5.046051f-6 0.00015961689 -9.139466f-5 3.129342f-5 -0.00012493518 -7.4743075f-5 -1.3715744f-5 0.00016584307 -7.423483f-5 9.143461f-5 -4.603196f-5 -3.5462876f-5 9.095226f-5 -2.2970607f-5 8.097345f-5 -0.000114065464 0.0001099024 -0.00022359997 3.0645565f-6 -0.00015064789 -1.6533049f-6 0.00013214462 -9.064665f-6 -6.234769f-5; 0.00016490389 -4.5384022f-5 1.0166787f-5 0.000110555855 -9.906058f-5 8.878647f-5 -1.00759025f-5 -7.185264f-5 -7.0030015f-5 2.3314149f-5 0.00016216584 -0.00012517942 -8.927f-5 -6.82787f-5 -0.00012155878 3.1689193f-5 2.8990258f-5 -3.491515f-5 -4.277549f-5 0.00016101898 -7.768926f-5 -5.3585954f-5 -4.0922736f-5 -9.9781006f-5 -5.2190397f-5 -8.386572f-5 0.00012216391 6.246569f-5 2.6801836f-5 3.412644f-5 -0.00012686178 2.023493f-6; 1.619283f-5 -0.000107150365 -1.3164305f-6 -3.7236587f-5 0.00010012421 6.97915f-5 -5.255718f-5 2.7940996f-5 7.918438f-6 2.5067946f-5 3.854037f-5 0.00014385219 -6.468716f-5 0.00014934498 0.00014246925 0.0001942224 0.00010170206 -1.0071727f-5 1.3559998f-5 -2.2907521f-5 2.0824158f-5 8.3106126f-5 -3.6005522f-6 -7.363514f-5 6.590196f-5 8.9784546f-5 0.00013432052 0.000100635916 0.00014314623 -4.1760857f-5 -3.2523327f-5 1.4351365f-5; -3.359051f-5 -3.055933f-5 4.975856f-5 9.951406f-6 6.0151564f-5 9.26924f-6 8.567208f-5 6.027107f-5 -9.057484f-5 0.00016618721 0.00010147432 6.691745f-5 -5.4613676f-5 -0.00013784601 -0.00027719754 7.052993f-5 
0.0001941064 2.07873f-5 -0.0001577997 0.00016459487 5.156228f-5 -4.4947643f-5 -0.00017503179 -5.3842723f-5 -7.678298f-5 -8.630341f-5 -3.0373692f-5 3.257268f-5 3.673096f-5 8.8847046f-5 4.3069733f-5 0.00015332723; 9.6588745f-5 -6.4615015f-5 -8.922941f-5 -2.0125988f-6 -5.6676076f-6 9.578341f-5 0.000110170746 3.203074f-5 -3.14685f-5 -2.7044562f-6 0.00018768724 6.614464f-5 -0.0001646646 -0.00015548717 -0.000101382706 1.2570278f-5 1.5305724f-5 1.5759808f-5 -5.268783f-5 7.728711f-6 0.00012465782 4.1416017f-5 -3.2673983f-5 3.0483574f-5 -0.00010750476 0.00010057767 5.0390525f-5 -8.587575f-5 -0.00019245455 -8.921298f-6 9.1628186f-5 -6.0486534f-5; -1.8747805f-5 9.835252f-5 9.900459f-5 3.855449f-7 -1.1450927f-6 -1.2842156f-6 0.0001505228 -2.9965426f-5 0.00012981167 -2.691968f-5 -2.0968877f-5 9.74793f-5 -0.00020526157 -1.3557011f-5 -3.4028108f-6 1.5976726f-5 -0.00015057733 -5.721396f-5 0.000184747 4.1546653f-5 0.00016613345 3.5390618f-5 -0.0001842597 0.00015946107 -2.582376f-5 2.6818327f-5 0.00011018058 2.0562411f-5 -0.00013391716 0.00010466569 -9.4221985f-5 -0.00010792875; -9.811993f-5 -8.236139f-5 -2.8603143f-5 -1.1468011f-5 3.1248393f-5 -3.8167385f-5 -3.2209387f-5 3.9021215f-5 9.607331f-5 2.4143246f-5 -0.00013088451 7.192475f-5 -4.4795254f-7 -5.5544293f-5 8.972765f-6 3.7823995f-6 0.00016858063 -6.3829844f-5 -2.7525226f-7 0.00018630018 -5.7466605f-7 8.0061756f-5 -5.4243832f-5 0.00013322335 0.000101166384 -2.569988f-6 1.7071898f-5 -9.4336254f-5 7.86326f-5 1.0714607f-6 -0.00012687568 0.00012527338; 9.752427f-5 -4.3974575f-5 -6.6581604f-5 -1.0627818f-5 -2.7240436f-5 -3.9927207f-5 2.1016638f-5 3.536431f-6 5.065232f-5 0.00011303551 4.191013f-5 8.965037f-5 -0.00020328394 4.9234946f-5 -0.000116116644 3.4937177f-5 2.657299f-5 -1.2267555f-5 -0.00017105127 0.000103749495 -0.000104750114 3.2322678f-5 -8.920638f-5 1.4184763f-5 -5.999597f-5 3.9508566f-5 -5.135275f-5 3.793911f-5 6.03961f-6 0.00014972486 -0.00011998434 0.00020815583; 1.8215434f-5 0.00012200313 1.6943868f-5 2.77974f-5 
1.1226332f-5 -1.9277073f-5 4.720047f-5 -1.0280557f-5 -8.389412f-5 9.540035f-5 -1.3365507f-5 8.976622f-5 0.00011280966 -0.00016639948 -1.1843532f-5 0.0001943641 4.4157596f-6 -5.2231306f-5 -0.0001360263 6.928031f-5 -5.9677044f-5 -2.0659033f-6 0.00015281455 -6.500199f-5 -0.00012297847 -1.029445f-5 3.740851f-5 0.00012436682 -8.671118f-5 0.00011508358 0.00012401902 9.2518094f-5; 4.472984f-5 8.10952f-5 -1.1355253f-5 -0.00016094414 4.294831f-5 -5.905569f-5 1.6661235f-5 2.8710503f-5 2.6005295f-5 -5.429565f-6 0.00013668872 3.0462148f-5 -7.1517745f-5 -0.00010090733 7.798974f-5 -3.8624163f-5 -6.0409562f-5 -6.7832894f-5 6.5904336f-5 9.823511f-5 9.1126885f-6 -0.0001860099 -5.2250638f-5 -0.00014020961 9.8793025f-5 -2.1824288f-5 -0.00013201093 0.00015491004 0.00011865592 1.0762648f-5 -3.5757353f-6 0.00013620092; 5.473531f-5 -7.671312f-5 7.128951f-5 -2.730904f-5 -9.310519f-6 3.4478f-5 6.36872f-5 5.679825f-5 -6.333436f-5 -7.01183f-5 -5.7925256f-5 6.568984f-5 -8.0653845f-5 -0.00016529285 -4.716212f-5 -7.5688884f-5 1.4067194f-5 -7.111188f-6 -0.00010336405 -0.00015981478 -7.27257f-5 -0.00017191938 -8.013809f-5 -0.00016004394 3.417455f-5 -6.103672f-5 -7.7855475f-5 4.0499322f-5 -2.867241f-5 -0.00010960639 0.00012915829 7.383914f-5; 4.5214394f-5 -7.0433285f-5 -0.00011821325 8.013893f-5 -4.976203f-5 -4.8106707f-5 8.1997205f-5 0.00012743284 -5.372316f-5 -4.2163585f-5 -1.6445181f-5 8.5437554f-5 -0.000101449645 -3.0047962f-5 -1.639068f-5 0.00011921904 -5.6416055f-5 -0.00017675376 8.524147f-5 -0.00012290214 8.9349f-5 0.000109636116 -0.00017566094 0.00014297433 6.896368f-5 0.00015725572 -1.7422728f-5 -0.00013611472 -0.00016198172 -7.884122f-5 7.768822f-5 1.7409866f-5; -5.359841f-7 6.164917f-5 -9.518823f-5 0.00013171781 -1.8415532f-5 -8.809816f-5 -3.9568207f-5 0.0002536066 0.00013241789 -0.00012571314 -0.00016330625 -0.00010531292 -2.2731705f-5 -3.8527915f-6 7.693935f-5 -2.9186227f-5 -4.4779183f-5 0.00032918912 1.9363595f-5 -3.4221473f-5 -7.65573f-5 -2.2588962f-5 3.0385394f-5 -1.2074952f-5 
-7.330127f-5 -0.00020046013 5.1070976f-5 -5.3816802f-5 4.7592068f-5 -0.00010298819 0.00010849015 0.00015951884; -9.693373f-5 -0.00013000688 -0.00012587754 4.735378f-5 3.2186115f-5 3.2204778f-5 8.7086455f-6 -0.00021973232 -0.00017441726 0.00013927721 -7.043499f-5 -0.00014156343 -0.00013560252 0.000108967186 -6.665942f-5 2.8041864f-5 5.9980284f-5 -0.00011682374 -5.1021107f-7 -3.6399713f-5 -0.000118463664 -7.462405f-5 -7.599712f-5 0.00012049323 -0.0001713647 -0.00012564352 -3.7130852f-5 -7.656362f-5 2.0386085f-5 0.00014206827 -0.00014274541 -0.00014260979], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.00014494236 3.8042403f-5 -3.5856083f-5 0.00010662776 -8.2596176f-5 1.1392547f-5 3.2741114f-5 -0.0001070538 3.580976f-5 3.8762464f-5 0.00023721643 -7.228617f-5 -0.00015385395 -7.5215844f-6 9.201995f-5 5.6072244f-5 4.5738252f-5 -1.7930955f-5 -2.5902304f-5 -0.00012620776 -0.00016140178 0.00010546495 0.00024153161 0.00015272872 -9.449422f-5 6.469381f-6 -0.00019181239 -3.174616f-5 -0.000120462886 0.00015835269 -1.5498672f-5 6.364192f-6; -8.686542f-5 0.00011894204 1.0042476f-5 -6.6040026f-5 -0.0001636497 -0.00021018248 0.000105849 -9.480914f-5 -8.089733f-5 -3.883982f-5 -0.0001976878 7.750569f-5 -0.00015728237 0.00015460362 0.0001448883 -6.784166f-5 0.00013997503 -3.26356f-5 1.0850052f-5 8.461288f-6 -6.4083715f-6 -6.7224086f-5 6.527159f-6 0.00011620968 4.217586f-5 -4.4131804f-5 -6.977044f-5 0.00023036718 -1.199927f-5 -6.774964f-5 0.0001551992 -0.000120859164], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most deep learning frameworks, Lux defaults to Float32 parameters. In this case, however, we need Float64, so we convert the parameters with `f64`.

julia
# Convert the Float32 parameters to Float64 and flatten them into a ComponentArray,
# which is what the ODE problem and the optimizer receive as `p`.
const params = ComponentArray(f64(ps))

# Bundle the network with its state so it can be called as `nn_model(x, ps)`;
# parameters are passed at call time (hence `nothing` here).
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

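Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected multiplicatively by the neural network $\mathcal{N}$, using

$u_1 = \chi, \quad u_2 = \phi,$

$$\dot\chi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot\phi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $p$, $M$, and $e$ are constants.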

julia
"""
    ODE_model(u, nn_params, t)

Neural-network-corrected Newtonian orbit ODE for the state `u = [χ, ϕ]`.

Both rates equal the Newtonian factor `(1 + e cos χ)² / (M p^{3/2})` times `1 + N(χ)`,
where `N` is `nn_model` evaluated with parameters `nn_params`. `p`, `M`, `e` come from the
global `ode_model_params`. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # `st` for this network contains only empty NamedTuples, so ignoring it here is safe.
    # In general, the state of the neural network should be carried through `st`.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    χ̇ = newtonian_rate * correction[1]
    ϕ̇ = newtonian_rate * correction[2]
    return [χ̇, ϕ̇]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Same setup as the reference simulation, but driven by the untrained NN-corrected dynamics.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep the first polarisation, as for the reference `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference data and the untrained model's prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

# Mean-squared error between the first waveform polarisation predicted with network
# parameters θ and the reference `waveform`, both sampled at `tsteps`.
function loss(θ)
    predicted_sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted_strain = compute_waveform(
        dt_data, Array(predicted_sol), mass_ratio, ode_model_params
    )
    return mseloss(first(predicted_strain), waveform)
end

Warm up the loss function (call it once before training):

julia
loss(params)
0.0007120310876873147

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

# Optimization callback: record the loss, print a progress line, and return `false`
# so the optimizer never stops early on our account.
function callback(state, l)
    push!(losses, l)
    iteration = state.iter
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" iteration l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This works well here because the Newtonian model already provides a very good initial guess.

julia
# Gradients of `loss` are computed with Zygote.
adtype = Optimization.AutoZygote()

# Optimization.jl calls the objective as f(u, p); only the parameters `u` matter here.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.00015136769798099908; 0.00018915694090521795; 7.710528734598228e-5; -8.100904233286776e-5; -3.4554857847968315e-5; -6.137268792373937e-5; 9.628864972914156e-6; 5.067023084840313e-5; 0.0001419479813192439; -0.00012489684741008887; 0.00011901706602652535; 0.00010201956320073964; 6.250788283056874e-5; -3.492094765532106e-5; 2.6304271159416428e-5; 9.691579180055896e-6; -5.673728537650946e-6; 0.00015087373321885566; -2.170498919438327e-5; 3.540720717861911e-5; 8.496613008894031e-5; -7.817985897408877e-5; 2.938671968873953e-5; -2.114699054799406e-5; -0.00015802774578314375; -5.6834804126943294e-5; 8.889765012998567e-5; 5.710302502840176e-5; -8.75308342073923e-6; -0.0001283062156289698; -9.810520714383053e-5; -0.00017922109691398598;;], bias = [2.1257874048952633e-16, 1.8243953103128644e-16, 8.629308440987583e-17, -1.3354296529347743e-17, -1.2707123422027693e-17, 1.934546259740816e-17, -4.3776016614930166e-18, 9.128348761236478e-17, 1.6719093669793447e-16, -4.481600781751427e-17, 5.4414174314875407e-17, 1.6626618792646438e-16, 1.3552192519652084e-16, -1.7978147367261283e-17, 5.0789316357772126e-17, 1.7567942993365975e-17, -8.626430926182648e-18, -1.1001538793593143e-16, -2.3627254537949028e-17, 3.5667653004790533e-17, 8.512025901882335e-17, -1.0662595605408612e-16, 3.906159541305963e-17, -1.1746971406482171e-17, -2.1472298844953123e-16, -5.690408114869887e-17, 1.2430137079487941e-16, 9.238887964582106e-17, -4.142416563172998e-18, -1.2679196418130657e-17, 1.4135264375881844e-17, -3.159991008945465e-16]), layer_3 = (weight = [-0.00014669790048559445 -3.1558890763565284e-5 4.6758280082742354e-5 -0.00013737344066086602 -3.8460591307932456e-5 0.00011173731822580952 -1.2780525461686019e-5 0.0002548905473720753 -0.00012666932830528312 7.393180699883403e-6 6.13887902052489e-5 5.287187800220359e-5 7.010204776250661e-5 5.347182072526964e-5 -4.536417632535787e-5 0.00010843489930793018 0.0001147975423425933 
-8.388971705629723e-5 -2.025508488642249e-5 -8.015402213842422e-5 7.3704838048621975e-6 -3.5598862590830334e-5 0.00011899876758051331 9.045526589560592e-5 0.00011893827526888342 -7.682474038503157e-5 9.736941924145856e-5 0.00011965380749260048 0.00017070617527296144 5.2679436199259636e-5 -2.2735849919829094e-5 3.852322756370235e-5; 8.878827042192365e-5 -4.299656962846764e-5 4.2912280822776835e-5 1.4402087562331813e-5 -0.00014910641790353633 -4.507057043244172e-5 0.00010820549481562439 6.593450938365426e-5 7.273547065690412e-5 -6.569590938031245e-5 5.244676448481805e-6 -0.00015571813879520935 -4.727425426696475e-5 7.240811773075824e-6 7.930523110102009e-5 4.944607771026047e-6 2.083437333479627e-5 8.796901214648188e-5 1.934953592266042e-5 -4.1245712292007544e-5 -3.060419863774618e-5 -5.7304110906414165e-5 0.0001543191899500329 -6.838953437269578e-5 -0.0001242120054130064 1.1915803715085862e-5 0.00026394982682089124 7.285728473928732e-5 9.029471379020351e-5 -1.4621321734896177e-5 -0.0001270403156567858 6.013839423669287e-5; 2.94158450853322e-5 -1.7572280092127504e-5 -0.0001659277152501136 0.00012469316762889727 2.394981648849272e-5 3.225967858470255e-5 0.0001426829728299709 0.00019644193863514186 0.00012231662159336578 0.00013731497502902728 -2.8597098469446922e-5 3.686951068829432e-5 1.628438413917753e-5 -3.62731227940466e-5 -5.459288903493912e-5 8.061700983689668e-5 0.00011920799780655227 -0.0001479276218449919 -2.1218709923960214e-5 2.357587046590112e-5 -4.458612992294194e-5 -6.260115723107916e-5 -0.00017414791753963128 0.0001347035756336479 1.0737703371140442e-5 2.279950577958275e-5 -9.460177924466436e-5 4.723048160685477e-5 7.880934182339851e-5 0.00019094889622664216 -0.00010571234842046417 -7.291871841768914e-7; 0.0002438512074501449 -9.719161058579403e-5 0.00014937419295043977 0.0001070820635856434 -5.3316065497103566e-5 0.00018166922931292534 -3.742390068630935e-5 -4.0798264273083246e-5 -8.677826004836429e-5 2.3221155096195872e-5 -8.814825738349251e-5 
6.729681448510845e-5 -0.000125088839887516 0.00012272714778979916 0.00015169785732602415 1.4444134684407472e-5 -1.6189240306022223e-5 0.0001357771743993472 -9.104360102476823e-5 -0.00010568224566789669 0.00011837176683790262 -4.446480848977059e-5 -7.638005805468486e-5 -2.552053771190151e-5 -9.195757044049023e-5 0.00016771187623695942 9.557257548173122e-5 -9.240951654768077e-5 -1.2782843406358867e-5 -8.761798559257236e-5 2.015448629996323e-5 8.286966298205261e-5; 0.00013516095941693896 -8.261520178817756e-6 8.427991861052331e-5 -0.00015172804114601128 -8.400831711670842e-5 7.267214187127491e-6 3.57613303097292e-5 -6.102977747237255e-7 0.0001510424540358914 -0.00020706909117481265 0.00014318251351167863 -4.611575195641757e-5 -0.0001174549483946199 -2.628314496152096e-5 9.380508572602189e-5 6.117072052828515e-5 1.7557828414542722e-5 0.00015907741044455585 6.684921982040175e-5 0.00014917170346987905 -3.149176472990328e-5 0.00010615076325874294 -0.00012780574024847237 -0.00010665424042573653 -0.00013812226234196317 -8.406998085748872e-5 1.0637285868401872e-5 -0.00011384311937919354 -5.6317979669724483e-5 -0.00011035544002853861 3.585772219620148e-5 0.000188013515526367; -0.00013861294888487446 -0.00013269692595643942 -5.0475185216095947e-5 4.1308702694328016e-5 -3.884904853517708e-6 0.00011694093240362656 1.853919542623184e-5 1.6459372037165174e-5 -0.00023897127999520363 -8.60392001952659e-5 -2.5171603792152706e-5 -0.00011072778312085889 -2.9837413475657144e-5 0.00012050753406652247 -4.601794813356381e-5 7.354557868962141e-5 2.8434666947535813e-5 -6.705967334018943e-6 2.933999434458533e-5 -3.933962542983972e-6 0.00014102995459393754 -1.0596317278436532e-5 -5.23830504419004e-5 1.9517227952168863e-7 -5.7023849021096984e-5 -1.1092146505002841e-5 9.65871089716149e-6 -8.000006806400554e-5 7.658850057919475e-5 1.850936218102862e-5 3.528368228285792e-5 0.00015348031132886072; -6.224870739065912e-5 6.971096300725167e-5 6.86891129680188e-5 5.673774868224016e-5 
-9.59102340766921e-5 -5.121353378892883e-6 7.546656562222767e-5 -0.0001644222516825192 -5.928024765130764e-5 4.0098608596917445e-5 0.00013757357229219483 9.377507959466391e-5 0.00012281793025063715 -4.46053723772283e-5 -1.6030782344424152e-5 -8.922084233749059e-5 0.000130004233515101 -5.486651351935091e-5 -4.107217099775918e-5 0.0001748244671125204 9.657503514931742e-6 -8.314931764293234e-5 -1.2673235314687805e-5 3.765684451533125e-5 7.670956111718905e-5 9.78259471684921e-5 -9.684922551063701e-5 -1.9087952377997038e-5 5.231702235501857e-5 0.00013630709000603825 -6.77961012025176e-5 7.149221264570112e-6; 8.515053193734788e-5 -2.0906718633106598e-6 -2.181989357229728e-8 -7.451038548947115e-6 2.9838491151554412e-5 -1.3920683180526435e-5 3.492935864089791e-5 -4.37102535430082e-5 5.89755909244792e-5 6.864652821279882e-5 -8.776111848032515e-6 -4.603537311011179e-5 7.284909950287819e-5 6.91417371639776e-5 0.0002326441681147347 -2.7019603315761427e-5 -3.61552265745328e-5 -0.00010111282610848718 -8.126728064196382e-5 -0.0001544015613249709 1.658886319676225e-5 0.00019015056748360203 -0.00011920708495706517 0.00010600571838624578 -8.64311296241561e-5 -8.967601754518419e-5 -0.00011112184440532888 -3.548176757574275e-5 -8.22989022922908e-5 0.00022909302058581082 -0.00020745570511605777 5.0288686885711074e-5; -1.6718299090536756e-5 1.4525853433327537e-5 0.00016850816769267446 5.7225735648283665e-5 -0.00011187149197819485 -1.112016093127351e-5 1.1279362982646148e-5 -4.120924336045865e-5 4.9137206363701674e-5 9.077254958658195e-5 9.855288385853862e-5 5.41497081515646e-5 0.0001684623291597057 2.1037563559116503e-5 5.4850666632160414e-5 -0.00023484989492014786 -5.991842723846599e-6 0.00011450143370286054 3.6054441582860356e-5 -0.00025574945743453534 5.6743718008259025e-5 1.827752726360231e-5 3.213152172368187e-5 5.9673880744708647e-5 0.0002544258766222968 9.878189462444836e-5 -2.636547310128391e-5 -0.00018813006886073296 -7.313780792780959e-5 -0.0001945158858204012 
-1.2138347525867719e-5 -5.0198354298741105e-5; 1.4552585934866196e-5 -4.714797937631968e-5 -4.049891827017138e-5 0.00012129181091003979 3.4230488790487654e-5 5.511285821703216e-5 0.00011422437496512187 3.7859158009811884e-5 -0.00014463656730222724 -6.35143692765308e-5 7.517644230537898e-5 0.00012221953916159348 1.5788610152955614e-5 3.828723170008488e-5 -0.00020482535505871307 0.00020159159289468998 5.012050632646903e-6 3.0383930106468028e-5 -9.44530580003654e-5 9.15191962572837e-5 0.00013424492176418112 -9.110134446964668e-5 1.4682998380155582e-5 -0.00010873144087818464 -2.2535231937950237e-5 -8.357415531500284e-5 0.00011062875678325917 -0.0001307386879037977 -2.549359624392926e-6 -0.00011776607020729391 -2.8509076901540272e-5 3.959160353551292e-5; 0.00022516679812253194 0.00012007637335481476 0.00012265118923537248 0.00010675891164723494 2.85291868975541e-5 -5.9195866265462735e-5 -0.00013068083287648047 4.96595209595879e-5 6.156040603691662e-6 5.784722518440379e-5 -0.00011584575192970871 -2.9188840267271824e-5 0.00014648942371864028 -0.00012752382397151867 0.0002473055580820586 0.00012657557883799885 0.00013340490906999333 -0.0002123252700359305 8.691842297126349e-5 7.477482788531718e-5 1.5716316488437074e-5 7.825535496962977e-5 0.00013720324998290023 8.257736837991326e-5 0.00011741594855703732 -0.00014724510335795311 9.464646493583403e-5 7.500272542970525e-5 -0.00012540893592371725 -6.652242816487431e-5 -2.7298028067125642e-5 0.00017513385021156666; -0.000101101950139165 0.00019497384913044134 2.0356708983992516e-5 -9.025244232256939e-5 -0.00019360990133359332 6.187918755235938e-5 9.369586608149769e-5 0.00010975781172524019 -0.00014757756595721948 -6.367513721565431e-5 -4.958486694599267e-5 1.224052469401951e-5 8.114492840060052e-6 -5.81448928358895e-5 7.550135908891195e-5 -5.736499021518321e-5 -0.00010391199045338945 6.74672539672962e-5 -3.62128319182832e-5 0.00017369877444357377 7.5649170240944324e-6 -7.089643783624342e-5 -0.00010272558008077801 
0.00010812384992382227 -3.932459164283544e-5 0.00018973274403028963 -0.0001550758186284359 -7.40167177071718e-5 -2.6139425034535096e-5 -0.00012526875231637304 -1.98236810036602e-5 0.00010637456419423177; -3.382216289468233e-5 -6.037114997455698e-5 -6.349284501924667e-5 3.912042540901799e-5 -2.833448104797054e-5 4.111136388088044e-5 -2.3924081123589132e-5 4.851024152299762e-5 8.80828216584288e-5 -0.00011494405169218125 5.78250097721691e-5 -6.402087581522223e-5 -9.529874057920249e-5 -0.0001254726388462971 1.9830661240810125e-5 0.00016823671341695213 6.122613433104326e-5 -7.412921093349472e-5 0.00010079885270966544 -1.7592410021500012e-5 -5.768586686757852e-5 6.241069134001466e-7 -6.742813400637023e-5 1.220357318812377e-5 -0.00016545920641805126 -2.156857533849855e-5 3.765240637290778e-5 5.21657772158125e-5 0.00015319955002338016 -2.118227110189647e-5 -3.9050429567614925e-5 2.422726231158319e-5; 0.00013146245048049348 4.1115316816301534e-5 -9.779029194829092e-5 -0.00010754111561323607 -3.4915839097040034e-5 -0.00013121847867364857 -2.8629369881628384e-5 -4.188916416819905e-5 -3.3929528472770787e-5 -0.00012430447084695 -5.21103045754262e-6 9.699346712226578e-5 0.0003122243332829804 0.0002528306912784034 -8.365775564084375e-5 0.00015666655683094143 1.5134921267644722e-5 -6.465523172734782e-5 0.0002301560279367998 3.7708013398956207e-6 2.643105009996654e-5 8.342886296235903e-5 1.8307217869049878e-5 -0.00012355369388456047 -8.957416419504869e-5 0.00015540780161176992 -0.00027021843114267307 7.60854353932216e-5 -7.390041065037067e-6 9.711814793193748e-5 8.77708779801115e-5 9.536755252995221e-5; -4.79765344696798e-6 0.00017457393073438396 -2.711352794993073e-5 1.373093687018036e-5 -4.33702382460776e-5 -7.308883982748563e-5 5.891885189068711e-5 4.053523184577179e-5 -4.4318884875774974e-5 6.78018648150284e-5 6.825163714169173e-7 -5.275602711751987e-5 -3.331507946402608e-5 -1.2188043229778226e-5 -8.443608862851546e-5 5.404322541746518e-5 -1.1304951159197608e-5 
0.00016811320051512586 0.0001625468601096349 7.111482659549646e-5 1.8925024732217385e-5 -2.9761398082863194e-5 -1.572162841601099e-5 8.71579493368903e-5 -2.8222911397261858e-5 -0.0001383151513900685 -2.934703047774628e-5 9.753108565657568e-5 6.148382340054235e-5 -0.00011644135997901679 -0.00022698825091627348 -2.6880686627757816e-6; 1.4361798886472112e-5 8.368529949753504e-5 -6.221768865411282e-5 0.00015794059424611629 6.762888358157936e-5 0.00016877241420975767 3.975076822048859e-5 -0.000150085659814562 1.0774943686061693e-5 -4.4112774165608864e-5 1.7218577223917485e-5 7.52926570125566e-5 -9.23594635683403e-5 -0.0002059789986820252 -3.317405583270018e-5 9.23199865742927e-6 0.00010322751756708012 8.451707242007297e-5 -0.00013205033622265918 -0.00017527989996646656 -2.7656868158868665e-5 0.00010319706040850728 -5.4316031426918996e-5 9.58835443011193e-5 8.86370397633587e-5 2.1111134511971604e-5 -7.96055466058023e-5 -0.00011641890899094871 4.7570294438133624e-5 0.00010940156050442632 1.5451061582513064e-5 -0.00015858508425348487; -0.00010097469863926616 -4.7063432489317356e-5 4.349207732498642e-5 -1.5210027654615944e-5 0.0002097244075548198 6.095570129178895e-6 -0.0001158394873209931 3.913472458858148e-5 1.8963365378067813e-5 0.00013372070019154495 3.986118166262942e-5 -0.00010023475557777253 -6.8663014326773925e-6 4.639322354247011e-5 0.0002550730502980126 -7.79327958889939e-5 -5.905316289958356e-5 0.0001014378728859351 0.00013414346243276434 5.095579466338418e-5 3.30395966875176e-5 -7.667822434545507e-5 3.080710698629709e-6 0.00011646511914928619 -1.6552049842222326e-5 -6.104533464626325e-5 8.030690581559134e-5 -2.554226672604597e-5 -0.00017780881104503574 -4.2429669761121024e-5 -2.1380238890784174e-5 -0.00011474165445638747; -7.543771973006392e-7 0.00012716380922469885 4.501187038370548e-5 -1.5218597695974813e-5 -5.910512948064211e-5 0.00011544738561026791 -9.404130935324483e-5 6.051362186083555e-5 -2.2957535407078933e-5 -1.1067132732173198e-5 
-0.00014319239664943372 0.0001140199664176476 0.00013570367343602747 1.3339515021314547e-5 5.81856454557755e-6 -2.190982297649735e-5 -6.663664783006664e-5 0.00014070883378266332 -0.0001740244271106098 8.280205177725091e-7 5.708703872626846e-5 -0.00011948053827397378 -6.614035476120281e-5 -0.00010017831766633428 2.0176516841129263e-5 -2.3317135060252465e-5 -8.05825194449141e-5 2.2133561968339456e-6 -7.942923650732391e-5 -4.3202232632360726e-5 -5.700765643738833e-5 -1.580051605301968e-5; -7.243045218763034e-5 1.2349799950549009e-6 3.294624513238696e-5 -7.598755144989058e-5 2.082156026027898e-5 -4.71463103034541e-5 2.2864920187762288e-5 -7.06021713862209e-5 -5.047517229627021e-6 0.00015961542122988522 -9.139612337583253e-5 3.1291954305397904e-5 -0.00012493664361080227 -7.474454140300129e-5 -1.3717210445415593e-5 0.00016584160367948684 -7.42362939357783e-5 9.143314274717955e-5 -4.6033427120514695e-5 -3.5464341986321014e-5 9.095079041309744e-5 -2.297207391655502e-5 8.097198189226695e-5 -0.00011406693045990775 0.00010990093470775532 -0.00022360143737876892 3.0630900388304735e-6 -0.00015064935395037857 -1.6547713906995086e-6 0.00013214315150855636 -9.066131847594216e-6 -6.23491545278451e-5; 0.0001649027920859735 -4.538511952772653e-5 1.0165689992617706e-5 0.0001105547574763928 -9.906167419899486e-5 8.87853760571019e-5 -1.007699997785769e-5 -7.185374001778841e-5 -7.003111263542914e-5 2.331305163466195e-5 0.00016216474737234648 -0.00012518051394067603 -8.927109470164536e-5 -6.827979691365963e-5 -0.00012155988101654 3.168809539715491e-5 2.8989160774696717e-5 -3.491624615775805e-5 -4.277658685964787e-5 0.00016101788182936385 -7.769036042077476e-5 -5.3587051887689675e-5 -4.092383337479325e-5 -9.978210314216347e-5 -5.219149411345206e-5 -8.386681897604993e-5 0.00012216281297810388 6.246459211183559e-5 2.680073826126952e-5 3.412534110907786e-5 -0.00012686287541214302 2.0223954845820434e-6; 1.619757539863585e-5 -0.00010714561859226106 -1.3116842932434959e-6 -3.723184045829242e-5 
0.00010012895881356213 6.979624821961929e-5 -5.255243514273602e-5 2.7945742197165066e-5 7.923184505292121e-6 2.507269208998979e-5 3.8545115323898764e-5 0.00014385693292988435 -6.468241691199083e-5 0.00014934972795582844 0.00014247399166617163 0.00019422715058009784 0.00010170680842986054 -1.0066980855682369e-5 1.3564743760632151e-5 -2.290277507871391e-5 2.0828904319028374e-5 8.311087230377539e-5 -3.5958059992332633e-6 -7.363039686186581e-5 6.590670317798419e-5 8.978929189889849e-5 0.00013432526838424045 0.00010064066236064921 0.00014315097586642666 -4.1756110938340835e-5 -3.2518580617214554e-5 1.4356111471500153e-5; -3.358963542166878e-5 -3.055845692768839e-5 4.975943491547135e-5 9.952280635046606e-6 6.0152438237423966e-5 9.270114133982612e-6 8.567295415662164e-5 6.027194584123655e-5 -9.057396411561108e-5 0.00016618808265877384 0.00010147519176927863 6.691832211902007e-5 -5.4612801904000225e-5 -0.00013784514041515167 -0.00027719666715609 7.053080597063461e-5 0.00019410727080060102 2.0788174225452724e-5 -0.00015779882843969468 0.0001645957393349138 5.1563155721963496e-5 -4.494676861111312e-5 -0.00017503091541482224 -5.384184871115798e-5 -7.678210269558085e-5 -8.630253316266617e-5 -3.037281814511222e-5 3.257355390917353e-5 3.67318328072776e-5 8.884792012497255e-5 4.307060706298461e-5 0.00015332810406208923; 9.658841980069833e-5 -6.46153396888813e-5 -8.922973695064185e-5 -2.0129237515215462e-6 -5.667932619530845e-6 9.578308770812805e-5 0.00011017042077526168 3.2030415858912515e-5 -3.146882347266261e-5 -2.70478114600645e-6 0.00018768691614184398 6.614431429568505e-5 -0.0001646649228138128 -0.00015548749554056943 -0.00010138303088808708 1.256995317760341e-5 1.5305399464864256e-5 1.575948288462081e-5 -5.2688153410258686e-5 7.728386371468295e-6 0.00012465749266937168 4.141569177534213e-5 -3.267430776816135e-5 3.0483248955874717e-5 -0.000107505087108277 0.00010057734714735051 5.0390199866684025e-5 -8.587607346328224e-5 -0.0001924548753964929 -8.921623085028195e-6 
9.162786107424466e-5 -6.048685947518741e-5; -1.874633092953619e-5 9.835399138394796e-5 9.90060627053454e-5 3.8701933462716855e-7 -1.1436182695892686e-6 -1.2827411730102762e-6 0.00015052427367195388 -2.9963951269702367e-5 0.0001298131439078493 -2.691820630858097e-5 -2.0967402483763455e-5 9.748077460685152e-5 -0.00020526010015160128 -1.3555536347803794e-5 -3.4013363212126975e-6 1.5978200431027395e-5 -0.00015057585081663079 -5.7212487113394516e-5 0.00018474847198817761 4.15481273783821e-5 0.00016613492804467235 3.539209243612908e-5 -0.0001842582231212421 0.00015946254021841918 -2.5822285201832526e-5 2.6819801100690215e-5 0.00011018205729200962 2.0563885450290803e-5 -0.00013391568321380863 0.00010466716518280144 -9.42205104685444e-5 -0.00010792727228833133; -9.811858493687262e-5 -8.236004672564807e-5 -2.8601801240265853e-5 -1.1466669128831765e-5 3.1249734966870125e-5 -3.81660438861302e-5 -3.220804595822074e-5 3.9022556812589126e-5 9.607465013776421e-5 2.414458773363095e-5 -0.00013088316794802657 7.192608878630156e-5 -4.4661107518490034e-7 -5.554295152160925e-5 8.974106059206145e-6 3.7837409863573184e-6 0.00016858197610508118 -6.382850282445221e-5 -2.739107926642457e-7 0.0001863015231365279 -5.733245770997558e-7 8.006309777646455e-5 -5.4242490513395635e-5 0.00013322469290960026 0.00010116772588024148 -2.5686465860499195e-6 1.7073239142508768e-5 -9.433491235450585e-5 7.863393962997408e-5 1.0728021706969188e-6 -0.00012687433358133797 0.00012527471968585543; 9.752428609185841e-5 -4.3974561277944745e-5 -6.658159095554617e-5 -1.0627804341719043e-5 -2.7240422652037815e-5 -3.9927193439733245e-5 2.101665136119817e-5 3.536444226985325e-6 5.0652334310645155e-5 0.00011303552177974015 4.191014390794554e-5 8.965038300498383e-5 -0.00020328392697724514 4.9234959577508276e-5 -0.00011611663053504194 3.4937189910626044e-5 2.6573003695877274e-5 -1.2267541434492543e-5 -0.0001710512601234298 0.00010374950811506248 -0.00010475010044633415 3.232269095729103e-5 -8.920637033174749e-5 
1.4184776363028217e-5 -5.999595800990249e-5 3.950857949216849e-5 -5.1352735182670905e-5 3.7939122692979864e-5 6.039623302560071e-6 0.00014972487274331035 -0.00011998432587592868 0.00020815584734787222; 1.8217795949589325e-5 0.00012200549506725387 1.6946229588064326e-5 2.7799761009226033e-5 1.122869387385659e-5 -1.927471159520767e-5 4.7202830027384504e-5 -1.0278195551071538e-5 -8.389175811405373e-5 9.54027095567928e-5 -1.3363145190822778e-5 8.976857905421148e-5 0.00011281202430749557 -0.00016639711568219135 -1.184117038151918e-5 0.0001943664578712227 4.418121060592464e-6 -5.222894426058329e-5 -0.0001360239362120018 6.928266928211307e-5 -5.967468270770284e-5 -2.0635418613544303e-6 0.0001528169127405297 -6.499962502944701e-5 -0.00012297610694165799 -1.0292088082641633e-5 3.741087173606207e-5 0.00012436918448569545 -8.670881972735205e-5 0.00011508594296085855 0.00012402137915981738 9.252045536324309e-5; 4.4730096123256693e-5 8.109545233224499e-5 -1.1354997878642754e-5 -0.00016094388338473397 4.2948566987489276e-5 -5.905543294005777e-5 1.6661490734860326e-5 2.8710758487305422e-5 2.600555017039293e-5 -5.429309750016193e-6 0.00013668897081430234 3.046240344616314e-5 -7.151748941481796e-5 -0.00010090707583089086 7.798999358698072e-5 -3.862390731313916e-5 -6.040930675311663e-5 -6.783263888877723e-5 6.590459161003399e-5 9.82353674763189e-5 9.112943916134425e-6 -0.00018600964467708721 -5.225038257862166e-5 -0.00014020935414095132 9.879328063021488e-5 -2.182403222917789e-5 -0.000132010674134053 0.00015491029698665507 0.0001186561719796928 1.0762903729355009e-5 -3.57547987818167e-6 0.00013620117606393073; 5.4731175782427813e-5 -7.671725309169165e-5 7.12853799765129e-5 -2.731317347333308e-5 -9.31465167505003e-6 3.447386568459761e-5 6.368306655949496e-5 5.67941171308133e-5 -6.333849144118405e-5 -7.012243424958732e-5 -5.792938811771947e-5 6.56857084091944e-5 -8.065797721919506e-5 -0.0001652969802538798 -4.716625319397412e-5 -7.569301653839134e-5 1.4063061075464767e-5 
-7.115320682140578e-6 -0.00010336818361374915 -0.00015981891176616107 -7.272983186830218e-5 -0.00017192351499524503 -8.014222096371777e-5 -0.00016004807532717845 3.417041823434798e-5 -6.104085316601203e-5 -7.785960754098661e-5 4.0495189909599844e-5 -2.8676542411087546e-5 -0.00010961052596519768 0.00012915415588814707 7.383500649179522e-5; 4.5213640065251976e-5 -7.04340391295529e-5 -0.00011821400131719309 8.013817840566752e-5 -4.9762785251170906e-5 -4.810746123819787e-5 8.199645070437078e-5 0.00012743208479679916 -5.3723914801755534e-5 -4.216433900562934e-5 -1.6445935229220577e-5 8.543679995074855e-5 -0.00010145039870092817 -3.0048716338735583e-5 -1.6391434668714816e-5 0.00011921828534997735 -5.6416808560409446e-5 -0.00017675451546363247 8.524071289304388e-5 -0.00012290289761420577 8.934824562476856e-5 0.0001096353616829391 -0.00017566169573381789 0.00014297357391644016 6.896292596668346e-5 0.0001572549690407123 -1.7423481962560487e-5 -0.00013611546911686847 -0.00016198247341361164 -7.884197567744049e-5 7.768746853418616e-5 1.740911192048391e-5; -5.356356727659215e-7 6.16495154641058e-5 -9.518788365353941e-5 0.00013171815871787138 -1.8415183968004407e-5 -8.80978122044829e-5 -3.956785885789476e-5 0.00025360693607539805 0.0001324182368075913 -0.00012571278866785804 -0.000163305900878601 -0.00010531257512453175 -2.273135660330387e-5 -3.852443024202355e-6 7.693970028371497e-5 -2.9185878611728995e-5 -4.477883421755439e-5 0.0003291894673562112 1.9363943556113614e-5 -3.422112504309102e-5 -7.655695017133964e-5 -2.258861322885027e-5 3.038574241104017e-5 -1.2074603455816878e-5 -7.330092275920651e-5 -0.0002004597780800898 5.1071324468961505e-5 -5.38164539652419e-5 4.7592416285210785e-5 -0.0001029878411831823 0.0001084905007697161 0.00015951918530803693; -9.693986699330075e-5 -0.00013001302336685045 -0.00012588368459021714 4.734764028923668e-5 3.217997531988602e-5 3.219863815117153e-5 8.702505552498204e-6 -0.00021973845919477536 -0.00017442340227189477 0.00013927107015278208 
-7.044112767672431e-5 -0.00014156957236466082 -0.0001356086567777626 0.00010896104599465568 -6.666556051117525e-5 2.803572446806082e-5 5.997414432000083e-5 -0.00011682988311276253 -5.163510233998684e-7 -3.64058532379572e-5 -0.0001184698039234967 -7.463018751353032e-5 -7.600326257874379e-5 0.00012048708690213767 -0.00017137084701443063 -0.00012564966068951991 -3.71369923907204e-5 -7.656976136262173e-5 2.0379945495512347e-5 0.00014206212749358457 -0.00014275155167908646 -0.00014261592782911325], bias = [3.245664437691989e-9, 1.1930197022113453e-9, 2.4189941542785593e-9, 2.0573121843940033e-9, 1.3687921234287899e-10, -8.195086106625228e-10, 1.8757509314164613e-9, -3.640122909317634e-11, 9.47469130237728e-10, 3.0236114737814143e-10, 5.275728524493401e-9, -1.1075543174570724e-9, -6.560905991028265e-10, 2.4855755779555195e-9, 4.099538477618437e-10, -7.287648631783646e-11, 9.97474633184967e-10, -1.4052831581038966e-9, -1.4664559650665583e-9, -1.0974406478107529e-9, 4.74619138713558e-9, 8.742800907362738e-10, -3.2498121544016827e-10, 1.4744408419212435e-9, 1.3414692080444087e-9, 1.323183067582592e-11, 2.361483933303736e-9, 2.5538525826657966e-10, -4.132585129335962e-9, -7.539840069277455e-10, 3.484270065425413e-10, -6.139957116124385e-9]), layer_4 = (weight = [-0.000837881693334656 -0.0006548971446485 -0.0007287955300005662 -0.0005863117272002276 -0.0007755357545227088 -0.000681547017301238 -0.0006601983862515657 -0.0007999933771152378 -0.0006571297989366752 -0.0006541771127779512 -0.00045572259157325835 -0.0007652257239041228 -0.0008467935198798082 -0.0007004610231333457 -0.0006009196224776035 -0.000636867335005367 -0.0006472013046803204 -0.0007108704903702654 -0.0007188418342396731 -0.0008191473075757197 -0.0008543408316264216 -0.0005874746152364949 -0.0004514079645934814 -0.000540210808403915 -0.0007874337594896016 -0.0006864701979893265 -0.000884751836412073 -0.0007246857373129887 -0.0008134020699223321 -0.0005345868769707373 -0.0007084382480377148 
-0.0006865745641549241; 0.00015430504421103602 0.000360112578254681 0.0002512129758152687 0.00017513048889374156 7.752084183830167e-5 3.098806122443097e-5 0.0003470195200349605 0.00014636140544521312 0.0001602732115572013 0.0002023307264577043 4.3482555162022874e-5 0.0003186762263856637 8.388817458066997e-5 0.000395774118934974 0.00038605884197263907 0.00017332888374989902 0.0003811455705808572 0.00020853493188842363 0.00025202058056521037 0.00024963182430741133 0.00023476199064614714 0.00017394645417077856 0.00024769770364671615 0.000357380210938993 0.00028334639076182385 0.0001970387412754288 0.00017140006191561352 0.0004715377240991287 0.00022917113855150354 0.00017342089871689623 0.0003963697473441241 0.00012031109647961482], bias = [-0.0006929395788713048, 0.00024117054559260676]))

Visualizing the Results

Let us now plot the loss over the course of training.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # One value per callback invocation, drawn as a thick line with markers on top.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, 1:length(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results by comparing the data with the untrained and trained neural network waveforms.

julia
# Re-integrate the neural ODE with the optimized parameters `res.u`, using the same
# fixed-step RK4 settings as during training, and extract the predicted waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, dt=dt, saveat=tsteps, adaptive=false)
)
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Prediction with the initial (untrained) network parameters.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Prediction with the optimized parameters `res.u`.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle,
    )

    # Each legend entry overlays its line and its markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU toolchain details are printed only for backends whose packages are loaded and
# report a functional device; the `@isdefined` guards keep this cell runnable without them.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA)
        if MLDataDevices.functional(CUDADevice)
            println()
            CUDA.versioninfo()
        end
    end

    if @isdefined(AMDGPU)
        if MLDataDevices.functional(AMDGPUDevice)
            println()
            AMDGPU.versioninfo()
        end
    end
end
Julia Version 1.12.4
Commit 01a2eadb047 (2026-01-06 16:56 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, icelake-server)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.