Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

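We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$, and use Newtonian formulas to get $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$ with $M = m_1 + m_2$ (e.g. Theoretical Mechanics of Particles and Continua, Section 4.3).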

julia
"""
    one2two(path, m₁, m₂)

Split a relative two-body `path` into the positions of the two bodies about their
common center of mass.

With `M = m₁ + m₂` this returns `(r₁, r₂)` where `r₁ = (m₂ / M) .* path` and
`r₂ = (-m₁ / M) .* path`, so that `r₁ - r₂ == path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, where $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$ and $y = r\sin\phi$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an orbit solution in `(χ, ϕ)` variables to Cartesian coordinates.

`soln` has one column per time sample and either
- 2 rows `(χ, ϕ)`: the constants come from `model_params = (p, M, e)` (`M` is unused), or
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary per sample.

The radius is `r = p / (1 + e cos χ)`, and the result is the `2 × n` matrix
`[x'; y']` with `x = r cos ϕ`, `y = r sin ϕ`.

Throws an `AssertionError` if `soln` does not have 2 or 4 rows. It also throws one if
`soln` has 2 rows and `model_params` is missing or does not have 3 entries.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    # The `∈` in the original assertion was lost; `in` is the ASCII equivalent.
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        # Check for `nothing` first so a missing argument yields the assertion
        # message rather than a MethodError from `length(nothing)`.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly sampled signal `v` with sample spacing `dt`.

Interior points use central differences. The first and last points use
second-order one-sided stencils, so `v` needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    centered = @views (v[3:end] .- v[1:(end - 2)]) ./ 2
    return vcat(left_edge, centered, right_edge) ./ dt
end
d_dt (generic function with 1 method)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly sampled signal `v` with sample spacing `dt`.

Interior points use the standard 3-point central stencil. The endpoints use
second-order one-sided 4-point stencils, so `v` needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    centered = @views v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    return vcat(left_edge, centered, right_edge) ./ dt^2
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

One component of the trace-free moment tensor of a planar `orbit` (row 1 = x,
row 2 = y), scaled by `mass`.

- `(1, 1)` returns `x² - (x² + y²)/3`.
- `(2, 2)` returns `y² - (x² + y²)/3`.
- Any other `component` returns the off-diagonal `x·y`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i, j = component[1], component[2]

    xx = x .^ 2
    yy = y .^ 2
    trace = xx .+ yy

    entry = if i == 1 && j == 1
        xx .- trace ./ 3
    elseif i == 2 && j == 2
        yy .- trace ./ 3
    else
        x .* y
    end

    return mass .* entry
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (sample spacing `dt`) of the `component` entry
of the mass-scaled trace-free moment tensor of `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Quadrupole components `(h11, h12, h22)` of `orbit` for a body of mass `mass`,
sampled with spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    h11, h22, h12 = quad((1, 1)), quad((2, 2)), quad((1, 2))
    # Return order is (11, 12, 22), which callers destructure positionally.
    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

(2,2)-mode strain polarizations `(h₊, hₓ)` of a single unit-mass body following
`orbit`, sampled with spacing `dt`.

`T` (the type of `dt`) sets the precision of the numeric constants.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Overall normalization √(π/5), as in Keith et al. (2021). Spelled with
    # `sqrt` so the square root cannot be silently dropped (which leaves π/5).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Combined quadrupole components `(h11, h12, h22)` of two bodies. Each body's
components are computed separately and summed entry by entry.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return map(+, body1, body2)
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

(2,2)-mode strain polarizations `(h₊, hₓ)` from the orbits of body 1 (mass
`mass1`) and body 2 (mass `mass2`), sampled with spacing `dt`.

Throws an `AssertionError` unless `mass1 + mass2` equals 1 to within `1e-12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Overall normalization √(π/5), as in Keith et al. (2021). Spelled with
    # `sqrt` so the square root cannot be silently dropped (which leaves π/5).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Convert an orbit solution `soln` into the (2,2)-mode strain polarizations
`(h₊, hₓ)`. `dt` is the sample spacing used for the finite-difference derivatives.

`soln` and `model_params` are interpreted as in `soln2orbit`. `mass_ratio`
(`q`) must lie in `[0, 1]`:
- `q == 0` gives the one-body (test-particle) strain.
- `q > 0` splits the relative orbit into two bodies with masses
  `m₁ = q / (1 + q)` and `m₂ = 1 / (1 + q)`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # The `≤` / `≥` operators were lost from these assertions, leaving the
    # undefined names `mass_ratio1` / `mass_ratio0`; restored in ASCII form.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi$, $u[2] = \phi$,

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side `[χ̇, ϕ̇]` of the relativistic orbit ODEs for the state
`u = [χ, ϕ]`, with constant parameters `p`, `M` and `e`. `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    shared = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    normalizer = sqrt((p - 2)^2 - 4 * e^2)

    χ_rate = shared * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * normalizer)
    ϕ_rate = shared / (M * (p^(3 / 2)) * normalizer)

    return [χ_rate, ϕ_rate]
end

# Problem setup shared by the true model, the neural model and the loss.
# NOTE(review): apart from `ode_model_params`, these are non-const globals read
# inside `loss`; marking them `const` would avoid type-unstable global access.
mass_ratio = 0.0         # test particle: compute_waveform takes the one-body branch
u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
datasize = 250           # number of saved samples
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform (Float32 endpoints)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved sample
dt_data = tsteps[2] - tsteps[1]  # sample spacing, used for the finite-difference derivatives
dt = 100.0               # fixed RK4 step size (the solves below use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4 (step `dt`), saving the
# state at `tsteps`; `Array` turns the solution into a state-by-time matrix.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns (h₊, hₓ); only h₊ is kept as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line plus markers at each sample.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the state variable $\chi$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

It is generally not recommended to use globals, but in case you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl (a truncated normal with a small standard deviation for the weights, and zeros for the biases):

julia
# Network layout:
# - `cos` applied elementwise to the scalar input,
# - two cos-activated hidden layers of width 32,
# - a linear layer with 2 outputs.
# Weights are drawn from a truncated normal with std 1e-4 and biases start at
# zero, so the untrained network's outputs are very small.
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
# NOTE(review): `Random.default_rng()` is not explicitly seeded here, so the initial
# weights can differ between sessions; pass a seeded RNG for reproducibility.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-2.6929512f-5; 0.00018981376; 8.5210355f-5; -1.6737322f-5; 0.00010250499; 8.728338f-5; 8.984874f-5; 1.6952994f-5; 3.2424089f-6; -1.9008969f-5; 8.083607f-5; -7.8833116f-5; 2.4738581f-5; -9.487601f-5; -3.355939f-5; 0.00011032536; 4.2071013f-5; -7.079802f-5; -1.6894368f-6; 1.7476523f-5; 0.00014546733; -0.00011259165; 5.727428f-5; -7.214386f-5; 6.5670174f-5; -5.3300333f-5; -1.5697879f-5; 9.638904f-5; -7.4446165f-5; -6.256097f-5; 7.22826f-5; -1.1016752f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[2.7181211f-5 -7.4973395f-5 -1.973043f-6 7.515609f-5 -0.00013121242 5.2939624f-5 6.4467175f-5 -0.00019519463 -0.00014116873 -2.4643052f-5 5.901564f-8 -5.9485577f-5 -1.3828867f-5 -8.77186f-5 -6.988924f-6 0.00015344813 -1.4318247f-5 -0.00013678607 -8.2587874f-5 -2.2887627f-5 -3.697211f-5 8.96225f-6 -4.0549614f-5 -0.00011447568 -0.00010526265 9.55979f-6 -0.0001383478 -0.000104105595 -5.7607765f-5 2.1888844f-5 9.228808f-5 0.00013523705; -4.5802695f-5 2.8404344f-5 0.00019994812 0.00016648772 -1.4791882f-5 -7.729306f-5 0.00010033749 -6.8065434f-5 -4.1996853f-5 0.00015107609 -3.5714755f-5 7.952848f-5 6.254838f-5 -0.00022857559 8.098557f-5 -6.942336f-5 -9.511261f-5 0.00015341923 -0.00011749918 -3.920416f-5 -0.00021130672 2.6923966f-5 -5.8245823f-5 -5.1851457f-5 2.197893f-5 -2.0284366f-5 9.169845f-5 -0.000112511996 9.7995835f-5 -0.00016347569 -2.8283872f-5 -0.0001224252; 8.981815f-5 1.1647313f-5 -4.2325744f-5 3.640947f-5 3.1134352f-6 -3.6182275f-5 9.1049704f-5 -3.7960624f-5 5.2886953f-5 -2.5289239f-5 -4.3843087f-5 -7.8027915f-5 0.00014565542 3.5431454f-5 2.2586766f-5 -0.00011937564 0.00020653951 0.00015257373 -9.902398f-5 -0.00014347163 -8.9916975f-6 9.59606f-5 1.9347031f-5 1.39988415f-5 -0.00017774654 0.00011883553 8.640436f-5 -4.5346547f-5 -4.4463464f-5 -1.6771806f-5 
2.2906326f-5 -6.970751f-5; -0.00017271371 0.000101906146 -9.034281f-5 0.00016842768 -6.212892f-5 -3.5693873f-5 2.740806f-5 8.716583f-5 1.8545687f-5 -0.00011745586 -9.3403396f-5 -1.885468f-5 -8.439638f-5 -3.83929f-6 0.00029246564 1.9709129f-5 -5.9152742f-5 -0.000101080455 -7.433823f-5 4.12363f-5 -6.96251f-5 -8.463439f-5 4.2963387f-5 -0.00014974386 3.394148f-5 2.2386386f-5 -4.087001f-5 0.00018560518 -1.730069f-5 4.3959506f-5 -0.00014297022 -9.257985f-6; 2.4927427f-5 -3.924995f-6 4.148369f-5 -0.000102267564 -1.7078826f-5 9.164007f-5 -1.8445315f-5 -0.00011121433 0.00010440225 0.00019963717 -0.00014236709 -8.997528f-5 -5.1412558f-6 -7.93774f-5 -2.7266971f-5 0.00013993598 2.1801123f-5 7.783963f-6 -0.00027939884 9.784755f-5 1.0055836f-6 -0.0002263965 2.2706854f-5 -0.00017380943 -0.00011278892 2.2710565f-5 -5.7923266f-5 1.6200564f-5 -7.662968f-5 2.5161082f-5 -4.531442f-5 -4.4020468f-5; -3.225065f-5 -0.00010047525 -3.2630614f-5 0.00014604186 -2.5247682f-5 -0.00015280322 -3.198595f-6 4.8905607f-5 -4.505739f-5 3.7675677f-6 -2.0586142f-5 3.8564904f-5 -1.9345023f-5 -0.00013770715 -3.9292205f-5 -0.0001358633 3.6161397f-5 -0.00027004414 -3.5386995f-5 -3.9902272f-5 3.451185f-5 2.0981672f-6 -0.00015331892 1.6082331f-5 3.1313724f-5 -3.107408f-5 6.919795f-5 8.18786f-5 -3.628433f-5 -3.765978f-5 -2.7009388f-5 -0.00012547988; -9.708601f-5 -0.00011882434 5.277656f-5 2.8887776f-5 -9.285283f-5 -2.3123777f-5 -6.291155f-5 -7.0890856f-5 -7.889724f-5 -5.5500957f-5 -1.7185183f-5 3.176948f-5 2.5291025f-5 3.5635174f-5 3.259072f-5 -5.3056683f-5 7.832824f-5 0.00014366276 6.563533f-5 -6.725163f-5 -0.00016290555 0.00012991267 8.2586535f-5 -5.3265652f-5 0.00013109161 -0.00014610615 0.00016742406 3.905198f-6 -5.835714f-5 0.00015458709 -4.096706f-5 -1.9494944f-5; 8.582904f-5 -4.0230498f-5 -1.4473361f-5 -0.00018686937 2.7649585f-6 -8.323275f-5 8.213123f-5 -6.170464f-6 5.654228f-5 -4.904182f-5 7.4876225f-5 -6.73908f-5 -4.4364348f-5 2.4766678f-5 4.1089108f-5 2.302751f-5 0.000110853616 9.245142f-5 
2.3892815f-5 0.000109747234 2.2043005f-6 -2.9606985f-5 8.257015f-5 3.3684635f-5 0.0002515062 0.0001927695 8.9528025f-5 -3.932429f-5 0.00011500205 1.01400865f-5 2.946345f-5 0.00013299403; -0.00011059399 9.40908f-5 0.00012769004 -0.00010614186 8.4784224f-5 -0.000101541016 -2.866452f-5 5.132033f-5 0.00010064976 1.2722947f-5 9.612458f-5 -0.00018997937 7.47885f-5 -8.583217f-5 -9.334893f-5 6.596327f-5 7.36742f-5 4.596379f-5 6.861756f-5 -5.155078f-5 -4.176891f-5 6.063279f-5 -7.0312664f-5 0.00011952867 1.976988f-6 0.00019986027 9.450506f-5 2.0750951f-5 -8.452549f-5 3.9159306f-5 -0.00013408513 -5.9972645f-5; -6.38442f-5 -7.250719f-6 -1.2588394f-5 -0.00023265494 -9.788938f-5 -5.779192f-5 -0.0001358655 4.6041278f-5 -0.00015010996 6.72954f-5 -6.4703134f-5 5.6080764f-5 -4.7294525f-5 -2.81178f-5 5.120316f-5 0.00024353055 3.8225167f-5 -0.00011746057 -0.00016014365 -0.00011821678 -9.724365f-5 7.3189844f-6 -8.169283f-5 -8.2156876f-5 3.356244f-5 0.00011676004 0.00014872168 -8.734157f-6 -2.8514989f-5 0.00018830475 -6.7142304f-5 -5.561995f-5; -7.02747f-5 -6.918633f-5 5.851559f-5 -1.0995198f-5 -0.00010000135 1.5412337f-5 5.9258262f-5 0.00017507424 -0.00025291415 -0.00011735865 -3.1869327f-5 -0.00010457928 -1.1268937f-6 9.843948f-5 -4.1814976f-5 5.331735f-5 0.00013505627 -1.585739f-5 -0.00012194872 4.0050807f-5 5.3343996f-5 -4.2418753f-5 8.423917f-5 -5.16253f-6 6.36837f-5 -7.3914576f-5 1.1051161f-5 0.00022883092 -0.0001438095 0.00018193877 4.206311f-5 4.73868f-5; -4.7455113f-5 -6.9852395f-6 -4.6192297f-5 -2.4254106f-5 7.971113f-5 0.00016894755 0.000109112116 8.525668f-5 1.2428433f-5 -1.936297f-6 9.5836265f-5 -1.9926547f-5 -2.9830468f-5 1.7625169f-5 0.00013858016 -3.0404913f-6 -7.500386f-5 2.9206356f-5 -4.5902038f-5 0.00020333497 9.828203f-5 -4.747551f-6 1.6848162f-5 -0.00014519106 0.000101807374 -4.6902646f-6 -9.293096f-6 -6.7268165f-6 8.942089f-5 -0.0002059537 -4.0710664f-5 4.5563633f-5; 3.7418395f-5 0.00011165389 1.6272004f-5 4.889404f-5 -3.5911715f-5 -1.8340435f-5 0.00011937267 
6.0646777f-5 2.3654535f-5 9.969632f-5 -2.341573f-5 7.411442f-5 -5.773508f-5 -6.639707f-5 -0.00011213939 -3.3867265f-5 0.000104142164 -7.2901466f-6 0.00022593285 -8.631635f-5 0.0002044126 -9.002176f-5 -4.118794f-5 2.979156f-5 5.8749352f-5 1.4125455f-5 -7.300774f-5 -0.00011529918 -1.3173615f-5 6.483965f-5 5.605993f-5 4.6224963f-5; 0.00017245606 5.2268773f-5 2.9550389f-5 -0.00012292463 4.8042577f-5 -6.921752f-6 0.00014708447 7.8247715f-5 -1.2354358f-5 7.07256f-5 7.2191564f-5 2.2183783f-6 9.824473f-5 1.711015f-5 -2.049585f-6 1.9332236f-5 5.8332498f-5 -3.6169724f-5 -1.473867f-5 9.6011274f-5 -0.00013185077 7.087475f-5 -0.0001160654 5.03267f-5 -0.00018520784 2.7303271f-5 -5.638441f-6 -0.00012577788 -6.572392f-5 5.828816f-5 -7.426266f-5 9.6586955f-5; 8.6191885f-6 -0.00011276409 -9.787721f-5 -0.00018328922 -3.7947975f-5 5.7567864f-5 -6.975176f-5 0.00012177719 6.261904f-5 -7.488384f-5 -4.9574217f-5 0.000107163025 -3.251545f-5 -0.00018028151 -1.1940221f-5 -0.00019667018 -0.000113603244 2.2919394f-5 9.648834f-5 3.4514356f-5 -0.00015365347 0.00013274408 -0.00018255817 0.00014086872 8.9934096f-5 -4.2211337f-5 3.9920607f-5 7.498533f-5 8.013647f-5 8.007657f-5 -3.9207374f-5 -9.102381f-6; -3.062911f-5 -2.3067523f-5 0.000114628805 0.00011872829 -4.6754016f-5 2.7966415f-5 -4.6925423f-5 1.9977848f-5 -4.1287687f-5 4.5922975f-6 -6.476771f-5 -7.12616f-5 3.5172325f-5 -3.420682f-5 -9.060599f-6 -4.9379287f-5 -9.3178955f-5 -6.856182f-5 -2.680662f-5 6.5713255f-5 0.0001413629 0.00011630069 0.000105706335 -0.00017583159 5.190791f-5 -2.989614f-5 7.001818f-5 -0.00013707147 -5.3583302f-5 9.684159f-5 -0.00019529552 -1.4428823f-5; -8.238687f-5 4.9678343f-5 5.692224f-5 -1.2941921f-5 7.272746f-5 1.1544172f-5 0.00016329717 -5.6702076f-5 7.142033f-5 -6.481646f-5 -6.965254f-5 -0.00011650259 3.0490763f-5 3.088319f-5 -6.318527f-5 -0.00027459985 -0.0001861249 3.3955326f-5 -8.792654f-5 -3.5305562f-5 3.776206f-5 -5.450432f-5 3.4908117f-5 2.0986443f-5 -8.393438f-5 -0.00016593793 -0.00010995956 3.6379777f-5 
-0.000110340436 -2.2801902f-5 -8.233003f-5 5.7094552f-5; 9.472561f-6 0.00019740338 -0.00010189943 -0.0001242014 -0.00019806213 -8.9243564f-5 -0.00016516108 -0.00010594848 4.203079f-5 3.4962362f-5 -0.00010991782 -6.1923038f-6 5.9564016f-5 -0.0001428539 -0.00013502214 5.9222388f-5 -2.8507711f-5 0.00010104281 0.00019720642 -5.5609426f-6 8.821602f-5 -5.5125052f-5 4.906034f-5 9.079564f-5 -2.3258748f-5 -1.0623791f-5 6.83789f-5 -7.938221f-5 -4.7599216f-5 0.000102743325 4.48901f-5 -3.236005f-5; -6.448193f-5 -8.4971274f-5 2.1625545f-5 8.339987f-5 9.568258f-5 -2.7531958f-6 -0.00019560732 4.391806f-5 7.320108f-6 0.00013537388 -3.3334487f-5 -1.09337825f-5 -5.2960873f-5 -2.1294543f-5 -6.214386f-5 -0.00027202078 -0.00012396742 -2.9872897f-5 9.044571f-5 0.00013660153 3.549737f-5 -4.9973387f-5 1.8327704f-5 -9.523039f-5 9.606121f-5 -4.4693655f-7 9.1296424f-5 1.0898611f-5 -9.873823f-5 6.5817054f-5 -6.0952352f-5 0.00012272279; 9.852909f-5 -9.3202354f-5 -0.00010824892 0.00011316742 6.550552f-5 -3.9781677f-5 0.00016462368 -6.793497f-5 -2.419692f-5 -3.6247064f-5 -8.361686f-6 -9.801765f-6 2.92325f-5 -0.0001727096 -0.00018047413 8.90268f-5 3.506837f-5 2.3454893f-5 -9.383121f-5 -8.5258325f-5 5.5575605f-5 0.00019571916 0.000142054 -1.2468458f-6 -3.7743866f-5 -0.00020572443 0.00010978415 -0.00017746136 3.67762f-5 0.00014892739 -2.2154996f-5 -0.00017683761; 1.9373309f-5 -1.918061f-5 8.529408f-5 -1.28047395f-5 6.331433f-5 0.00018026974 -0.00012294909 0.00030539284 -0.0001302266 -4.016439f-5 3.636107f-5 1.0913011f-5 1.9140374f-5 -6.299686f-5 0.0001709617 9.47405f-5 -3.922633f-5 -0.000111133224 0.00017926384 3.6705133f-5 4.3906835f-5 -6.690753f-5 3.5840236f-5 3.2619002f-5 -6.1540895f-6 6.163629f-5 -0.00020283484 2.1832582f-5 -0.0001171062 -7.7648925f-5 0.0002174454 1.7071216f-5; -0.00012672733 -1.9362424f-5 -6.348565f-5 -5.6361532f-5 9.200136f-5 -7.924215f-5 -1.2660473f-5 4.6657122f-5 -1.1172133f-5 -0.00010459944 0.00020239032 9.860132f-5 5.290432f-5 -0.00013340624 -9.370582f-5 -1.4130605f-6 
0.00012412555 -0.00012697138 -2.7795057f-5 0.00011659704 -6.476891f-5 -2.3773273f-5 7.601464f-5 -8.872713f-5 -5.8438127f-5 -1.4173579f-5 3.3984309f-6 -4.47754f-6 0.00010736574 3.0251453f-5 -3.812149f-5 -0.00011319723; 3.371868f-5 0.00013443826 5.1827854f-5 -8.6799766f-5 -8.8833054f-5 6.100437f-5 -9.015614f-5 -0.00015620922 0.00011579238 4.6871108f-5 -0.000111769106 -5.6423327f-5 -4.5123732f-5 2.1397766f-5 -7.7836325f-5 4.2284923f-5 -4.1332292f-5 -0.00010564545 6.337535f-5 2.2464586f-5 1.38632f-5 -7.5812936f-6 -0.00021531287 -6.097358f-5 -0.0001145827 -3.5759425f-5 2.5175295f-6 0.00010256529 -0.00013256994 -7.10266f-5 0.00015515374 0.00012863512; 0.00016938224 -6.1727847f-6 7.351327f-5 6.053199f-5 2.8826125f-5 4.0151746f-5 0.00010414845 6.0681476f-5 -4.6352834f-6 -1.9819632f-5 5.2888325f-5 2.9697878f-5 2.9526132f-5 -8.66012f-5 -2.18068f-5 -6.0854985f-5 -0.0001442393 -1.1877175f-5 2.1036738f-5 -4.178559f-5 4.0596937f-5 5.7049678f-5 0.00015102074 -3.9127764f-5 1.0607519f-5 0.00012553699 -8.207494f-5 -6.0937225f-5 -7.84692f-5 1.824489f-5 -1.8778536f-6 -1.5278107f-5; 9.86986f-5 0.00011178038 -0.00015044518 -6.334768f-5 0.000109204426 7.285886f-5 -9.2600625f-5 0.00011747379 -4.2817584f-5 0.00010111074 0.000106857624 9.835138f-6 -5.1704887f-6 0.00010282027 0.00014313807 0.00012373802 -2.3620047f-5 2.1821546f-5 6.122055f-5 0.00011169796 -0.00010699859 -0.00012207477 2.062985f-5 0.000116058865 8.850792f-5 0.0001196328 -0.00020679123 7.5083095f-5 -3.129029f-7 -3.6197252f-5 0.000121184414 -0.000114104136; -7.602979f-5 -2.5869536f-5 -5.155578f-5 -0.00015892592 -4.9893668f-5 -2.734091f-5 9.5885836f-5 -6.2313006f-6 -5.2972755f-5 -0.00012575419 -3.1938216f-5 7.2837305f-5 -0.00024563057 3.9540293f-5 -2.6104986f-5 -0.00010137934 0.00012463647 -0.00010902507 4.2457566f-5 -8.375032f-5 6.6187444f-5 -5.1533305f-5 1.7981958f-5 -0.000117664174 -1.2011039f-5 -4.5365145f-5 -7.364516f-5 -6.486815f-5 5.8600184f-5 -4.203497f-5 -3.2485477f-5 4.5075005f-5; -0.00013070542 -0.00015975433 
-6.992425f-5 0.0002099442 6.1440296f-5 -5.5624467f-5 -0.00010954051 0.00010350125 0.00013992924 5.120921f-5 5.6765402f-5 6.9081885f-5 2.6757074f-5 -5.5876393f-5 -1.511445f-5 -0.00010202999 -9.522961f-5 0.00016932531 9.387505f-5 0.0001063169 -2.5045027f-5 0.0001045822 -4.9932452f-5 -0.00012684085 -8.448334f-6 0.00010434034 -0.00010065335 7.570756f-5 -0.00010008741 7.5593925f-5 4.958557f-5 7.100681f-5; 9.346449f-5 8.065102f-5 -0.0001231683 0.00012845539 3.61032f-5 -0.00010899335 4.0846182f-5 -1.9712192f-5 -2.452599f-5 -0.00029979285 -5.196527f-5 -9.586883f-6 6.468827f-5 4.194355f-5 9.447276f-5 -0.00014939476 0.000119398974 2.9458113f-5 -3.9788f-5 -4.9356026f-5 3.0536183f-5 9.2327784f-5 5.9385846f-5 0.0001586432 -5.486276f-5 -0.00013903994 -2.5824826f-5 -8.27096f-5 5.6255154f-5 -8.366687f-5 6.415546f-5 2.4961846f-5; -1.7154618f-5 -6.25401f-5 0.00013183664 -4.154787f-5 4.6544724f-6 4.7508565f-6 5.167622f-5 1.1392743f-5 -0.00011841812 3.5660407f-5 1.0554406f-5 5.0109506f-5 -4.4523254f-6 8.575221f-5 -9.3798866f-5 3.639836f-5 -1.8728966f-5 -0.00019036532 -7.332643f-5 -7.249745f-5 5.9433336f-5 7.36477f-5 -2.0806767f-5 4.8500628f-5 -0.00011557353 3.885173f-5 9.786825f-5 1.2558773f-5 -7.953113f-5 1.7916214f-5 0.00014363094 -8.705342f-5; -9.4098295f-6 -7.6423064f-5 8.952854f-5 5.7538964f-6 -6.483624f-5 4.0347215f-5 9.717157f-5 -2.31686f-5 8.8374305f-5 -0.00023891937 5.5408844f-5 -0.00016257844 7.194669f-5 -9.041726f-5 4.255883f-5 8.219531f-5 -6.25591f-6 4.08506f-5 -2.4120787f-5 -8.082002f-5 -5.181983f-5 0.00019782949 -8.996561f-5 8.368457f-5 1.49951975f-5 -5.246516f-5 8.9217334f-5 8.673243f-5 4.533018f-5 0.00018186381 3.5822617f-5 -8.640922f-5; 6.031883f-7 9.489666f-5 0.00013656454 -9.544131f-5 -9.683716f-5 -9.127037f-5 0.00010762211 -1.5133531f-6 -6.284508f-5 2.369077f-5 -2.3863591f-5 8.066414f-5 6.553842f-5 0.00011113959 1.0472403f-6 0.00017817711 -0.00014969378 -0.00011678407 -6.9730275f-5 -0.0001974443 9.4495044f-5 5.803256f-5 -5.3002746f-5 7.6416174f-5 -6.3233725f-5 
6.117161f-5 -3.9345017f-5 0.00025667725 2.5234429f-5 -2.5647718f-5 5.346883f-6 -9.2683964f-5; -0.00018939369 5.8539154f-5 -0.00010850468 0.0001118518 1.8812738f-5 4.891557f-5 -2.0645693f-5 7.53004f-5 -4.4426426f-5 -4.807725f-5 -0.000111647285 -0.000110212495 -2.360084f-5 -0.00010546466 1.1253425f-5 -0.00017753747 4.8341567f-6 0.000110940826 6.8192516f-5 -0.00010594281 -5.8846894f-5 1.1336939f-5 -1.8298972f-5 -4.3319156f-5 2.1685926f-5 0.00031025952 3.3905286f-5 6.461463f-5 3.649047f-5 -2.9638764f-5 0.00018975354 -3.9982922f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[0.0001708779 4.4241037f-6 1.4100209f-5 1.2816746f-5 -2.5358646f-5 0.00016335303 2.250023f-5 -0.0001626948 -3.2744138f-5 1.3899964f-5 -8.063333f-6 -0.00028158276 -1.1681926f-6 -4.976991f-5 6.505884f-6 -0.00012898148 -8.202538f-5 5.314386f-6 8.4425155f-5 -7.636122f-5 1.57976f-5 -5.9597656f-5 -8.146553f-5 -6.3207815f-5 -6.0987797f-5 9.512302f-5 2.8953767f-5 2.6497093f-5 1.5623647f-5 8.08009f-5 -0.00016826541 -0.00010676208; 1.575427f-5 9.748395f-5 -9.326786f-5 7.175283f-5 5.936708f-5 -0.00015665338 -0.00011727081 -1.6288703f-5 -0.0001014183 3.0831892f-5 0.00012100857 2.612674f-5 6.932443f-5 -0.00012696009 3.233908f-5 -0.00013499355 -9.2992166f-5 -8.3339786f-5 -1.8018718f-5 0.00012028642 -0.00018586835 -0.000112399386 3.4080847f-5 7.0763366f-5 0.00012099564 -3.82022f-5 3.1342017f-5 -9.867962f-5 9.590853f-5 -2.8015153f-5 -5.3827374f-5 -0.00021827056], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to Float32; however, in this case we need Float64.

julia
# Convert the Float32 parameters to Float64 and wrap them in a ComponentArray so the
# optimizer can treat them as one flat vector.
const params = ComponentArray(ps |> f64)

# Bind the network to its state so it can be called as `nn_model(x, ps)`. Parameters
# are supplied at each call, hence `nothing` in place of stored parameters.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, using

$u[1] = \chi$, $u[2] = \phi$,

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side `[χ̇, ϕ̇]` of the neural-network-corrected Newtonian orbit model.

The state is `u = [χ, ϕ]`, and `nn_params` are the network parameters; `t` is unused.
Both Newtonian rates share the factor `(1 + e cos χ)^2 / (M p^(3/2))`. That factor is
multiplied by `1 .+ nn(χ)`, so a network whose outputs are near zero approximately
reproduces the Newtonian model.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `st` is an empty NamedTuple in this example, so using the stateful wrapper
    # directly is safe; in general the network state should be tracked explicitly.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural model with the untrained parameters, using the same
# fixed-step RK4 settings and save points as the reference solve.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only h₊ (the first polarization) for comparison with the data.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained model's waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean-squared-error loss between predicted and reference waveforms.
const mseloss = MSELoss()

# Training objective for parameters θ:
# 1. re-solve the neural ODE with fixed-step RK4 at the same save points as the data,
# 2. convert the solution to a waveform and keep h₊,
# 3. return its MSE against the reference `waveform`.
# Reads the globals prob_nn, u0, tsteps, dt, dt_data, mass_ratio and waveform.
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)

Warm up the loss function:

julia
# Warm-up call: evaluate the loss once at the untrained parameters.
loss(params)
0.0006780852410125551

Now let us define a callback function to store the loss over time

julia
# Loss history, appended to by `callback` once per optimizer iteration.
const losses = Float64[]

"""
    callback(state, current_loss)

Optimizer callback. It records `current_loss` in `losses` and prints a progress line
using `state.iter`.

Always returns `false`, i.e. it never requests that optimization stop early.
"""
function callback(state, current_loss)
    push!(losses, current_loss)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter current_loss
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This appears to work well because the Newtonian model already provides a very good initial guess.

julia
# Gradients of `loss` are taken with Zygote. The objective ignores its second argument,
# which is the fixed-parameter slot of Optimization.jl's `f(u, p)` convention.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations.
# `callback` records and prints the loss at every iteration.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-2.692951238710547e-5; 0.00018981375615104366; 8.521035488228613e-5; -1.673732185736467e-5; 0.00010250499326490161; 8.728337706992472e-5; 8.984874148152734e-5; 1.6952993973940427e-5; 3.2424088658431408e-6; -1.9008968592968413e-5; 8.083607099237094e-5; -7.883311627666762e-5; 2.4738581487305023e-5; -9.487601346324346e-5; -3.355939043100229e-5; 0.0001103253598556752; 4.207101301288157e-5; -7.079802162466305e-5; -1.6894367718092129e-6; 1.7476522771159566e-5; 0.0001454673329134903; -0.00011259165330553503; 5.7274279242798434e-5; -7.214386278053898e-5; 6.567017408078608e-5; -5.330033309282252e-5; -1.569787855260911e-5; 9.638904157326163e-5; -7.44461649446382e-5; -6.256096821746361e-5; 7.22826007403585e-5; -1.1016752068823068e-5;;], bias = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = [2.717752956230945e-5 -7.497707723084014e-5 -1.976724872298384e-6 7.51524055594788e-5 -0.00013121610412054156 5.293594201065786e-5 6.44634926793469e-5 -0.0001951983142426028 -0.00014117240810448105 -2.464673364740089e-5 5.533379481260755e-8 -5.9489259212575686e-5 -1.3832548786026498e-5 -8.772228487214653e-5 -6.9926056603565e-6 0.00015344444524816977 -1.4321928785660376e-5 -0.00013678974958748687 -8.259155604902457e-5 -2.2891308828562012e-5 -3.697579107391688e-5 8.958568076196648e-6 -4.0553295759319904e-5 -0.00011447936433979415 -0.00010526632852998775 9.556107914238992e-6 -0.00013835147568171615 -0.00010410927664634383 -5.761144732969878e-5 2.1885161702802598e-5 9.228439718571388e-5 0.00013523336362412412; -4.58039684239683e-5 2.8403070354741878e-5 0.0001999468441859473 0.00016648644679622692 -1.4793154951288077e-5 -7.729433116804528e-5 0.00010033621958388314 -6.806670706322445e-5 -4.1998125998760064e-5 0.0001510748189680834 -3.5716027814883265e-5 7.95272063620038e-5 6.2547107626926e-5 
-0.0002285768640653569 8.098429690977565e-5 -6.942463358466973e-5 -9.5113886236902e-5 0.00015341795380624182 -0.0001175004507608157 -3.920543156749598e-5 -0.00021130798984852703 2.6922693104613844e-5 -5.824709649499408e-5 -5.185273028513239e-5 2.1977657633727883e-5 -2.0285638930033115e-5 9.16971767942977e-5 -0.00011251326930544921 9.799456176512136e-5 -0.000163476964990365 -2.8285145054210987e-5 -0.00012242647779243105; 8.981897558036266e-5 1.1648135155634776e-5 -4.232492206434891e-5 3.6410292564630676e-5 3.1142574657991906e-6 -3.6181452700898e-5 9.105052599401018e-5 -3.7959802243187856e-5 5.288777542367602e-5 -2.528841646786764e-5 -4.384226481712576e-5 -7.802709292669924e-5 0.00014565623864249333 3.543227653202382e-5 2.2587588360272292e-5 -0.00011937481754509106 0.00020654033146050735 0.00015257455383610802 -9.902315646692154e-5 -0.0001434708028268347 -8.990875316584372e-6 9.596141903378683e-5 1.934785363622412e-5 1.3999663720402173e-5 -0.00017774571456763053 0.00011883635310872405 8.640518357927931e-5 -4.534572504277568e-5 -4.446264206714121e-5 -1.6770984033849902e-5 2.2907148418686113e-5 -6.970668614891335e-5; -0.0001727149965334162 0.00010190486229497203 -9.034409138601367e-5 0.00016842639617627096 -6.213020513294222e-5 -3.569515591132535e-5 2.7406776116750526e-5 8.716454342012832e-5 1.8544403265165785e-5 -0.000117457139803976 -9.340467930067456e-5 -1.8855963371406775e-5 -8.439766223071522e-5 -3.840573280994066e-6 0.0002924643604457901 1.9707845258583513e-5 -5.9154025602567494e-5 -0.00010108173863025573 -7.433951666806297e-5 4.123501732484978e-5 -6.962638415622961e-5 -8.463567335618526e-5 4.2962104107772146e-5 -0.00014974514759140268 3.3940196790553285e-5 2.2385102984315676e-5 -4.087129398593141e-5 0.00018560389891738061 -1.7301973439116785e-5 4.3958222722921e-5 -0.00014297150753898725 -9.259268135595396e-6; 2.4924459269339455e-5 -3.927962750179906e-6 4.1480722559458707e-5 -0.00010227053209503761 -1.708179408280684e-5 9.163709981537567e-5 -1.8448282587382006e-5 
-0.00011121729695999642 0.0001043992822657558 0.00019963420416631164 -0.00014237005890875088 -8.997824553928333e-5 -5.144223648432754e-6 -7.938036944080797e-5 -2.726993918016118e-5 0.00013993300743339243 2.1798154905508604e-5 7.780995276946364e-6 -0.0002794018066354809 9.784458305390905e-5 1.0026157339690361e-6 -0.0002263994676625944 2.2703886118207076e-5 -0.00017381239669459964 -0.0001127918869472779 2.2707596856589876e-5 -5.792623343707646e-5 1.619759567136252e-5 -7.663264770356069e-5 2.515811391519626e-5 -4.531738815542404e-5 -4.402343549130149e-5; -3.2254534281634295e-5 -0.00010047913337176158 -3.263449934015389e-5 0.0001460379721800954 -2.5251566959143827e-5 -0.00015280710021320094 -3.2024799940211575e-6 4.890172250165313e-5 -4.5061274696452076e-5 3.7636828063505015e-6 -2.059002644377726e-5 3.8561018763040125e-5 -1.9348908145872578e-5 -0.0001377110306667828 -3.929608964531867e-5 -0.00013586718659202286 3.6157511666574405e-5 -0.00027004802495062045 -3.5390879466974654e-5 -3.990615686339616e-5 3.450796476397222e-5 2.09428231652067e-6 -0.00015332280553698454 1.6078446159417334e-5 3.13098394225094e-5 -3.107796464169366e-5 6.919406633536965e-5 8.18747144047442e-5 -3.628821604356726e-5 -3.7663666708933356e-5 -2.7013272748399905e-5 -0.00012548376934424352; -9.7086220694546e-5 -0.00011882454832673608 5.2776350866697835e-5 2.888756848689784e-5 -9.285303945078295e-5 -2.312398514775825e-5 -6.291176082457213e-5 -7.089106343344636e-5 -7.889744715656007e-5 -5.5501165252717566e-5 -1.718539037981809e-5 3.176927152664591e-5 2.5290817074637516e-5 3.5634965979188916e-5 3.259051250053725e-5 -5.30568909657444e-5 7.832802971235635e-5 0.0001436625565925352 6.563512165441325e-5 -6.725184043760262e-5 -0.00016290575563332704 0.00012991246644516706 8.2586327561803e-5 -5.326586010640304e-5 0.00013109140440930792 -0.0001461063553056011 0.0001674238550999869 3.904990034716687e-6 -5.835734605214373e-5 0.00015458688308128332 -4.0967268933381914e-5 -1.9495152219388077e-5; 8.583308030748195e-5 
-4.022645421550311e-5 -1.446931778819093e-5 -0.0001868653218062643 2.7690019541234116e-6 -8.322870612946878e-5 8.213527130091894e-5 -6.166420602395862e-6 5.654632500037103e-5 -4.9037777129425974e-5 7.488026844466644e-5 -6.73867563401477e-5 -4.436030408586904e-5 2.4770721070940802e-5 4.109315145424901e-5 2.3031554379386486e-5 0.00011085765967973754 9.245546047705038e-5 2.3896858552606033e-5 0.0001097512775649268 2.208343945201167e-6 -2.960294128034552e-5 8.257419344400098e-5 3.368867860072078e-5 0.00025151023678145277 0.00019277354098986007 8.953206801916925e-5 -3.9320248246535846e-5 0.00011500609056920484 1.014412996110232e-5 2.9467493592074762e-5 0.00013299807128160727; -0.00011059295478910744 9.409183281537271e-5 0.00012769107730425133 -0.00010614082542881891 8.478525729448595e-5 -0.00010153998277085674 -2.8663487447130907e-5 5.1321363807138426e-5 0.00010065079573839393 1.2723980270739996e-5 9.612561576365692e-5 -0.00018997833807655405 7.478953217220891e-5 -8.58311393501319e-5 -9.334789478155403e-5 6.596430490323346e-5 7.367523381551156e-5 4.596482200146902e-5 6.861859606876459e-5 -5.154974836907872e-5 -4.1767878193001475e-5 6.06338219397607e-5 -7.03116308984409e-5 0.00011952970113204106 1.9780213817917335e-6 0.00019986130087933825 9.450609673810603e-5 2.0751984501901187e-5 -8.45244573981571e-5 3.9160339164727436e-5 -0.00013408409699034636 -5.9971611099835054e-5; -6.384690143607171e-5 -7.253420814226408e-6 -1.259109606182631e-5 -0.00023265764550745138 -9.789208209114752e-5 -5.7794621649393125e-5 -0.00013586820071070344 4.603857585022572e-5 -0.00015011266307413254 6.7292695779152e-5 -6.470583550837849e-5 5.607806221958194e-5 -4.729722652557863e-5 -2.8120501634858484e-5 5.120045668588441e-5 0.0002435278488279275 3.822246569468187e-5 -0.00011746327302391806 -0.00016014635414324376 -0.00011821948512663717 -9.724634812883687e-5 7.316282766244226e-6 -8.169553123154996e-5 -8.215957725626565e-5 3.355973737923203e-5 0.00011675734046873278 0.0001487189779014175 
-8.736858514039921e-6 -2.8517690704051637e-5 0.00018830204996772782 -6.714500571108569e-5 -5.562265007613826e-5; -7.027416052238255e-5 -6.918579367862885e-5 5.851612820464831e-5 -1.0994660102875662e-5 -0.00010000081295089209 1.5412875085023098e-5 5.9258799750243095e-5 0.00017507477548264736 -0.0002529136114631516 -0.00011735811188060023 -3.1868789507423475e-5 -0.00010457874361832894 -1.1263558704347163e-6 9.844001703953985e-5 -4.181443785142331e-5 5.331788940879142e-5 0.00013505680487782505 -1.5856851520363846e-5 -0.00012194817884533824 4.0051344852548325e-5 5.334453396557296e-5 -4.2418215004144343e-5 8.423970518692181e-5 -5.1619923638624e-6 6.36842372600245e-5 -7.391403835441193e-5 1.105169883289213e-5 0.0002288314566371278 -0.00014380896812013608 0.00018193930679757855 4.206364917394454e-5 4.738733639300642e-5; -4.745306957407729e-5 -6.983196517820089e-6 -4.619025437056408e-5 -2.425206345866278e-5 7.971317084378019e-5 0.00016894958956039166 0.00010911415868543643 8.525872383837744e-5 1.2430476299086002e-5 -1.9342540057197973e-6 9.583830817940372e-5 -1.9924503949701493e-5 -2.982842467678111e-5 1.7627211700193836e-5 0.00013858220175246805 -3.0384483310354975e-6 -7.50018168858899e-5 2.9208398909427414e-5 -4.58999945934565e-5 0.0002033370167097371 9.828407039969818e-5 -4.745508040623909e-6 1.6850204910621487e-5 -0.0001451890218689776 0.00010180941738293385 -4.6882216978366065e-6 -9.291052965964935e-6 -6.724773602734734e-6 8.942293627991137e-5 -0.00020595164985057988 -4.070862066361304e-5 4.5565675672732094e-5; 3.742072195421557e-5 0.00011165621882210214 1.6274331672709314e-5 4.889636594128481e-5 -3.590938749000599e-5 -1.8338108103828904e-5 0.0001193749984926639 6.0649103860153464e-5 2.365686205706267e-5 9.969864814450646e-5 -2.3413403224887827e-5 7.411674497091374e-5 -5.773275287680667e-5 -6.639474403671043e-5 -0.00011213705974256995 -3.3864937987870724e-5 0.00010414449107741509 -7.287819327367259e-6 0.00022593517397438814 -8.631402405768946e-5 0.00020441492861974928 
-9.001942916400137e-5 -4.1185613637544064e-5 2.9793886741052904e-5 5.875167963352363e-5 1.412778229535513e-5 -7.3005409914519e-5 -0.00011529685533930623 -1.3171287802444037e-5 6.484197993670929e-5 5.6062256506655895e-5 4.6227290029885955e-5; 0.00017245743127276082 5.227014751723195e-5 2.9551762946337162e-5 -0.00012292325211021206 4.804395119279073e-5 -6.9203779450822585e-6 0.0001470858469299649 7.824908878390592e-5 -1.235298352113999e-5 7.072697097171941e-5 7.21929381877799e-5 2.219752542058065e-6 9.824610513604144e-5 1.7111523430482327e-5 -2.0482107349555835e-6 1.9333609980799526e-5 5.8333872385247e-5 -3.6168349669625255e-5 -1.4737295731919157e-5 9.60126480829614e-5 -0.00013184939691130175 7.087612082684694e-5 -0.00011606402496838233 5.0328074376719705e-5 -0.00018520646792169 2.7304645282838337e-5 -5.637066823709441e-6 -0.00012577650433670293 -6.572254735076194e-5 5.82895326995452e-5 -7.426128378867482e-5 9.658832912535329e-5; 8.617544718097565e-6 -0.0001127657373031534 -9.787885153816038e-5 -0.00018329085958573372 -3.794961901764166e-5 5.756622033043827e-5 -6.975340454989026e-5 0.00012177554946940989 6.261739383301348e-5 -7.488548670114348e-5 -4.95758607754111e-5 0.0001071613811476993 -3.2517094386614315e-5 -0.00018028315333100826 -1.1941864958836809e-5 -0.0001966718258514131 -0.00011360488804673102 2.291775001179348e-5 9.648669276469038e-5 3.451271242777482e-5 -0.0001536551129716462 0.00013274244031020777 -0.00018255981502040465 0.00014086708017242074 8.993245193817537e-5 -4.221298093359183e-5 3.9918963575862695e-5 7.498368415366079e-5 8.013482931385698e-5 8.007492635481853e-5 -3.920901815906235e-5 -9.104025029566796e-6; -3.063001229619017e-5 -2.3068424232925637e-5 0.00011462790400530064 0.00011872738952790034 -4.675491680116443e-5 2.7965513546786775e-5 -4.692632381064009e-5 1.9976946824878635e-5 -4.128858807105933e-5 4.5913964835717056e-6 -6.476860897108152e-5 -7.12625030051469e-5 3.517142369802698e-5 -3.420772070838233e-5 -9.061500057744362e-6 
-4.9380187809746716e-5 -9.317985557443772e-5 -6.856272255261642e-5 -2.6807520660971195e-5 6.571235400898961e-5 0.00014136200599880228 0.00011629978809780456 0.0001057054335628144 -0.0001758324930766153 5.1907010239500074e-5 -2.989704142108372e-5 7.001727619482781e-5 -0.0001370723754722168 -5.35842033774146e-5 9.684069056804044e-5 -0.00019529642184223258 -1.4429724322856734e-5; -8.239049225908648e-5 4.967471954204702e-5 5.69186156489691e-5 -1.2945545001507219e-5 7.272383640425577e-5 1.1540548023690456e-5 0.00016329354524983103 -5.67056997688766e-5 7.141670334289755e-5 -6.482008053085751e-5 -6.965616217443896e-5 -0.00011650621332323394 3.0487138922531605e-5 3.087956769646264e-5 -6.318889086929045e-5 -0.0002746034736056017 -0.00018612852750907465 3.395170255585214e-5 -8.79301662799253e-5 -3.5309185711461966e-5 3.775843628588775e-5 -5.450794220181556e-5 3.4904492843406185e-5 2.0982819244592812e-5 -8.393800113212762e-5 -0.00016594155773150423 -0.0001099631847474072 3.637615349630977e-5 -0.00011034405930063051 -2.280552531105957e-5 -8.233365247819909e-5 5.7090928515158455e-5; 9.471364698985314e-6 0.00019740218396460183 -0.00010190062619818097 -0.00012420259316291659 -0.00019806332728199258 -8.924476004045166e-5 -0.00016516227640982412 -0.00010594967478260396 4.2029595185741675e-5 3.496116607260895e-5 -0.00010991901525135834 -6.19350010543361e-6 5.95628194946052e-5 -0.00014285509258348154 -0.00013502333911906894 5.922119145675101e-5 -2.8508907582871015e-5 0.00010104161321345903 0.0001972052237919717 -5.562138887906891e-6 8.821482025946383e-5 -5.512624864655058e-5 4.905914521936447e-5 9.079444554794774e-5 -2.3259944448506725e-5 -1.0624987573971882e-5 6.83777003518492e-5 -7.938340730338932e-5 -4.760041282002605e-5 0.00010274212819917271 4.488890464694435e-5 -3.236124701521364e-5; -6.448271306177928e-5 -8.49720570858696e-5 2.1624761937375465e-5 8.339908817836678e-5 9.568179957389759e-5 -2.7539786533158897e-6 -0.00019560810760773853 4.3917275614176304e-5 7.319325230155425e-6 
0.00013537309423375537 -3.3335269838033135e-5 -1.0934565399991797e-5 -5.296165618526146e-5 -2.129532622526978e-5 -6.214464229012223e-5 -0.000272021567800008 -0.00012396819947932216 -2.987368025239307e-5 9.044492908108636e-5 0.00013660075201008128 3.549658795057995e-5 -4.997416981674749e-5 1.8326921415819983e-5 -9.523117187379137e-5 9.606042585622284e-5 -4.4771944837435717e-7 9.129564132129435e-5 1.0897828540721789e-5 -9.873901290299572e-5 6.581627094240615e-5 -6.095313510920131e-5 0.00012272200838017865; 9.852823038969823e-5 -9.320321352999877e-5 -0.0001082497774608165 0.00011316656216664862 6.550465966726637e-5 -3.9782536098147814e-5 0.00016462281835417176 -6.793582817620907e-5 -2.4197778677566348e-5 -3.624792318290407e-5 -8.362545506617892e-6 -9.80262486549001e-6 2.9231640817366343e-5 -0.00017271045455557357 -0.00018047499272355163 8.902593888494216e-5 3.506750997773105e-5 2.3454033887598788e-5 -9.383206727064532e-5 -8.525918476364884e-5 5.557474558160929e-5 0.0001957182966886372 0.00014205314708093885 -1.2477053336012713e-6 -3.774472590733514e-5 -0.00020572528685291536 0.00010978329280775661 -0.00017746222240232914 3.67753409388273e-5 0.0001489265259603297 -2.2155855751709827e-5 -0.00017683846910798305; 1.93756879220766e-5 -1.9178230862551615e-5 8.52964587161083e-5 -1.2802360136918099e-5 6.331670753431703e-5 0.0001802721164189973 -0.000122946708719177 0.0003053952170071012 -0.0001302242233882836 -4.016200894831624e-5 3.636345005950466e-5 1.0915389945218993e-5 1.9142753596026813e-5 -6.299448417259998e-5 0.0001709640857065841 9.47428781423445e-5 -3.922394883695113e-5 -0.00011113084459204535 0.0001792662152788424 3.67075119051902e-5 4.390921473392583e-5 -6.690515131915893e-5 3.58426151856111e-5 3.262138140278547e-5 -6.151710143531864e-6 6.163866798167909e-5 -0.00020283245961174307 2.183496159566392e-5 -0.00011710382371668119 -7.764654568156448e-5 0.00021744777783758618 1.7073594941832874e-5; -0.00012672850356950516 -1.936359444271897e-5 -6.348681858555254e-5 
-5.636270293082217e-5 9.200019190273536e-5 -7.924331952865061e-5 -1.2661643491192177e-5 4.665595111389538e-5 -1.117330361496309e-5 -0.00010460061389069177 0.00020238915090401794 9.860014757094095e-5 5.290315016707313e-5 -0.0001334074106537807 -9.370699194307514e-5 -1.4142312878160426e-6 0.00012412438150447743 -0.0001269725537397984 -2.7796227409747672e-5 0.00011659586812554751 -6.477007922996065e-5 -2.3774443657492295e-5 7.601347238476396e-5 -8.872830153024672e-5 -5.843929761369663e-5 -1.4174749817650376e-5 3.3972601201426842e-6 -4.4787109303066145e-6 0.00010736456966213048 3.0250281956503504e-5 -3.812265935750595e-5 -0.0001131984020006016; 3.3716824054084356e-5 0.00013443640645997878 5.182599861577519e-5 -8.680162127425956e-5 -8.883490945740969e-5 6.1002516394317304e-5 -9.015799869380476e-5 -0.00015621107732009207 0.00011579052816081713 4.6869252491108474e-5 -0.00011177096157630585 -5.642518217505586e-5 -4.5125587258423184e-5 2.1395910243108004e-5 -7.783817991444854e-5 4.228306724976485e-5 -4.13341475913112e-5 -0.00010564730464543919 6.33734963222585e-5 2.2462731157377667e-5 1.3861345086440551e-5 -7.583148909223957e-6 -0.00021531472875799867 -6.0975435099729735e-5 -0.00011458455255783911 -3.5761280651836655e-5 2.5156742587050333e-6 0.00010256343384676173 -0.00013257179662182672 -7.102845838822264e-5 0.00015515188724670798 0.00012863326273795142; 0.0001693837130946873 -6.171312994898807e-6 7.351474241975583e-5 6.053346269571452e-5 2.882759700103393e-5 4.01532180180574e-5 0.00010414992184767348 6.068294724489791e-5 -4.6338117467867204e-6 -1.9818160486936175e-5 5.2889796373601234e-5 2.969934949167477e-5 2.9527604150173523e-5 -8.659972617944901e-5 -2.180532819376605e-5 -6.085351369304003e-5 -0.00014423782513087456 -1.1875703327422985e-5 2.1038209618821372e-5 -4.17841200910672e-5 4.0598408398611004e-5 5.705114936379339e-5 0.00015102220832293936 -3.912629281014216e-5 1.0608990217843414e-5 0.0001255384578175667 -8.207346936298302e-5 -6.093575384195174e-5 
-7.846773032408236e-5 1.824636101762816e-5 -1.8763819494050205e-6 -1.5276635073421017e-5; 9.87018015107249e-5 0.00011178358039927791 -0.00015044197567160781 -6.334447749079001e-5 0.00010920762946941005 7.286206022715549e-5 -9.259742142287707e-5 0.00011747699540454167 -4.281438082809326e-5 0.00010111394153089458 0.00010686082754862035 9.838342014949686e-6 -5.167285104579692e-6 0.00010282347142808345 0.00014314126969427917 0.0001237412238365026 -2.361684306062712e-5 2.1824750041083237e-5 6.122375129369321e-5 0.0001117011656274201 -0.00010699538466409938 -0.00012207156174238955 2.0633053685267234e-5 0.00011606206903113663 8.85111243801262e-5 0.00011963600485931081 -0.000206788024621818 7.508629864193673e-5 -3.0969925458054724e-7 -3.6194048840591645e-5 0.00012118761737244813 -0.0001141009320749786; -7.603381637165182e-5 -2.5873561030280228e-5 -5.155980507407266e-5 -0.00015892994379538961 -4.989769259264193e-5 -2.73449342828884e-5 9.58818116443529e-5 -6.235325331378612e-6 -5.297677962955043e-5 -0.00012575821274805722 -3.194224082545402e-5 7.283328049669907e-5 -0.000245634591196737 3.953626837319064e-5 -2.6109011018721326e-5 -0.00010138336183881596 0.00012463244529061299 -0.0001090290981708769 4.245354175071276e-5 -8.375434227202004e-5 6.618341903523589e-5 -5.153732964099529e-5 1.797793338045024e-5 -0.00011766819916811971 -1.2015063661975705e-5 -4.536916931539355e-5 -7.364918674722238e-5 -6.487217724934812e-5 5.859615957450898e-5 -4.2038996257189646e-5 -3.2489501977406555e-5 4.5070980190924026e-5; -0.00013070404497605127 -0.00015975295650472917 -6.992287785445018e-5 0.0002099455803104088 6.144167014233012e-5 -5.562309274730899e-5 -0.00010953913556343592 0.00010350262050585529 0.00013993061179596653 5.12105823431507e-5 5.676677642574889e-5 6.908325917854343e-5 2.6758448031136483e-5 -5.5875019141721915e-5 -1.5113076413877829e-5 -0.00010202861562768829 -9.522823643103603e-5 0.00016932668836622922 9.387642671888414e-5 0.00010631827058339196 -2.504365261168994e-5 
0.00010458357314880082 -4.993107836392206e-5 -0.00012683947829694072 -8.446959553591546e-6 0.00010434171304175221 -0.00010065197534324894 7.570893535691543e-5 -0.00010008603680810493 7.559529945089955e-5 4.958694512558583e-5 7.100818652495456e-5; 9.346431475124224e-5 8.06508440756343e-5 -0.00012316847564661004 0.0001284552098251031 3.6103022331287674e-5 -0.00010899352778223402 4.084600445922409e-5 -1.9712369202255343e-5 -2.452616822565076e-5 -0.00029979303067092876 -5.196544875515689e-5 -9.587060184360674e-6 6.468809156204476e-5 4.19433716625529e-5 9.447258057762773e-5 -0.0001493949413533232 0.00011939879627889861 2.9457935716988725e-5 -3.978817689939504e-5 -4.93562030432197e-5 3.053600526173621e-5 9.23276064450943e-5 5.93856683291114e-5 0.00015864302699911337 -5.486293880393806e-5 -0.00013904011861895295 -2.5825003057337438e-5 -8.270977846122395e-5 5.6254976562838306e-5 -8.366704710069191e-5 6.415528046191598e-5 2.4961668578823015e-5; -1.7154920050516864e-5 -6.254040278609597e-5 0.00013183633921552405 -4.154817367553822e-5 4.654170362200729e-6 4.750554517968176e-6 5.1675917285037826e-5 1.1392440620410538e-5 -0.00011841841848682727 3.566010485054815e-5 1.0554103874610247e-5 5.010920352186956e-5 -4.4526273714936455e-6 8.575190798263974e-5 -9.379916816134108e-5 3.639805793763146e-5 -1.8729268112211526e-5 -0.00019036561940912223 -7.332673346273716e-5 -7.249775450792274e-5 5.943303397624851e-5 7.364739934099852e-5 -2.080706878506194e-5 4.8500325584327185e-5 -0.00011557383192585852 3.885142717162061e-5 9.786794901709892e-5 1.2558471035745247e-5 -7.953143000951096e-5 1.7915912331672568e-5 0.00014363063740428474 -8.705372244468072e-5; -9.408712702323541e-6 -7.64219476137648e-5 8.952965798142838e-5 5.755013245463211e-6 -6.4835123356363e-5 4.034833180538838e-5 9.717268765724706e-5 -2.3167483976620414e-5 8.837542189338985e-5 -0.00023891825186966192 5.5409960606438446e-5 -0.00016257731842289172 7.194780844783196e-5 -9.041614613907934e-5 4.255994643371319e-5 
8.21964240289043e-5 -6.254793350862115e-6 4.085171529494781e-5 -2.4119670359694455e-5 -8.081890338745169e-5 -5.1818711670982105e-5 0.0001978306063223378 -8.996449107017683e-5 8.368568886125077e-5 1.4996314331610911e-5 -5.246404181762895e-5 8.921845072235713e-5 8.673354385008503e-5 4.5331296843314323e-5 0.0001818649288737926 3.5823733907263614e-5 -8.640810485176937e-5; 6.039125568111023e-7 9.489738687613173e-5 0.00013656526818146814 -9.544058746917123e-5 -9.683643628170906e-5 -9.126964476684232e-5 0.00010762283301653206 -1.5126288955126646e-6 -6.284435556529032e-5 2.3691495076376522e-5 -2.386286695048311e-5 8.066486435916184e-5 6.553914528546898e-5 0.00011114031468957505 1.0479645342217143e-6 0.00017817783765701345 -0.00014969305237548408 -0.00011678334311375454 -6.97295505998464e-5 -0.0001974435760386097 9.449576856774758e-5 5.803328560056851e-5 -5.300202218143289e-5 7.641689836952772e-5 -6.323300084125294e-5 6.117233380416399e-5 -3.934429304001799e-5 0.0002566779697334563 2.523515314885341e-5 -2.5646993603119975e-5 5.347607217001996e-6 -9.268323963999693e-5; -0.0001893939063571052 5.8538933199494514e-5 -0.00010850490277294335 0.00011185157721269562 1.881251749222857e-5 4.891534784332587e-5 -2.0645913782168916e-5 7.530017840549036e-5 -4.442664637240899e-5 -4.807747088266851e-5 -0.00011164750528183105 -0.00011021271554414455 -2.3601060338074685e-5 -0.00010546487671450033 1.12532049847177e-5 -0.00017753768976904683 4.8339363551749625e-6 0.00011094060549152734 6.819229541219454e-5 -0.00010594303082103223 -5.8847114150516467e-5 1.1336718426214137e-5 -1.8299191896909064e-5 -4.331937659076694e-5 2.168570584257778e-5 0.00031025930159598285 3.390506547515896e-5 6.461440873903904e-5 3.649024959526679e-5 -2.9638984660374104e-5 0.00018975331998303255 -3.998314265987221e-5], bias = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = 
[-0.0005248472811231692 -0.0006913013529506758 -0.0006816252714896892 -0.0006829087105522275 -0.0007210839222965911 -0.0005323721207250484 -0.0006732252664838436 -0.0008584198772609422 -0.0007284696083968272 -0.0006818253536356079 -0.0007037888224367318 -0.0009773081433891331 -0.0006968935550424883 -0.0007454953616756276 -0.0006892195465210763 -0.0008247069591883644 -0.0007777505454105558 -0.0006904110764918679 -0.000611300327104867 -0.0007720866968465896 -0.0006799277628604311 -0.0007553231193260155 -0.0007771909379542242 -0.0007589332573440018 -0.000756713031282202 -0.0006006020964642337 -0.0006667716834228024 -0.000669228403166707 -0.000680101847379529 -0.000614924568360049 -0.0008639908944574981 -0.0008024875791548495; 0.00027745574329003783 0.00035918552458351805 0.00016843372460754322 0.0003334544059330008 0.00032106858861403776 0.00010504808613556255 0.0001444307827327146 0.00024541272755581337 0.00016028328049136543 0.0002925334151451333 0.0003827101558794211 0.0002878282868121233 0.0003310259689391 0.0001347414820684727 0.0002940406466051116 0.00012670803423300302 0.00016870930075242162 0.00017836179097347816 0.00024368286671002306 0.0003819880053999689 7.583318601299174e-5 0.00014930219136352166 0.00029578240472079336 0.00033246493579225036 0.0003826971306870024 0.00022349924581561406 0.00029304358978971814 0.00016302196776355686 0.00035761011942555827 0.00023368642652122864 0.00020787421074262703 4.343102496395716e-5], bias = [0.0, 0.0]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # Loss values pushed by `callback`, drawn as a line with one marker per iteration.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results.

julia
# Re-simulate the neural ODE with the optimized parameters `res.u` and convert the
# trajectory into a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Prediction from the initial (untrained) parameters, computed before optimization.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Prediction from the optimized parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # One legend entry per series, each combining its line and markers.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
# Print the Julia version, platform and environment details of this run for reproducibility.
InteractiveUtils.versioninfo()

# Also report GPU toolchain details, but only for backends whose packages are loaded in
# this session (checked with `@isdefined`) and that MLDataDevices reports as functional.
# The nested `if`s make the whole cell evaluate to `nothing`, so no stray value is displayed.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()  # blank line separating this report from the preceding output
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()  # blank line separating this report from the preceding output
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.