Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie
Precompiling OrdinaryDiffEqLowOrderRK...
   4050.3 ms  ✓ OrdinaryDiffEqLowOrderRK
  1 dependency successfully precompiled in 4 seconds. 98 already precompiled.

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3)

julia
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of the two
individual bodies with masses `m₁` and `m₂`.

Returns the tuple `(r₁, r₂)` where `r₁ = (m₂ / M) .* path` and
`r₂ = (-m₁ / M) .* path`, with `M = m₁ + m₂`, so that `r₁ - r₂ == path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    return (m₂ / total_mass) .* path, (-m₁ / total_mass) .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in orbital variables into Cartesian orbit coordinates.

# Arguments
- `soln`: matrix with one column per saved time. It must have either 2 rows
  (`χ`, `ϕ`) or 4 rows (`χ`, `ϕ`, `p`, `e`).
- `model_params`: required when `soln` has 2 rows; a length-3 collection
  `(p, M, e)`. Only `p` and `e` are used here. Ignored when `soln` has 4 rows.

# Returns
A `2 × N` matrix whose first row is `x = r cos(ϕ)` and whose second row is
`y = r sin(ϕ)`, with `r = p / (1 + e cos(χ))`.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    # ASCII `in` is used instead of `∈` so the membership test cannot be lost in transit.
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        # With only (χ, ϕ) in the state, p and e have to come from the fixed parameters.
        @assert model_params !== nothing "model_params must be provided when size(soln,1) = 2"
        @assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        # p and e are evolved as part of the state.
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function computes the first time derivative with second-order finite differences: central differences in the interior and one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled signal `v` (spacing `dt`).

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`; the two endpoints
use three-point one-sided stencils. The result has the same length as `v`, which
must contain at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)

    left_edge = -1.5 * v[1] + 2 * v[2] - 0.5 * v[3]
    interior = (v[3:n] .- v[1:(n - 2)]) / 2
    right_edge = 1.5 * v[n] - 2 * v[n - 1] + 0.5 * v[n - 2]

    return vcat(left_edge, interior, right_edge) / dt
end
d_dt (generic function with 1 method)

This function computes the second time derivative with second-order finite differences: the central stencil in the interior and one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled signal `v` (spacing `dt`).

Interior points use the central stencil `v[i-1] - 2v[i] + v[i+1]`; the two endpoints
use four-point one-sided stencils. The result has the same length as `v`, which
must contain at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)

    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    right_edge = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]

    return vcat(left_edge, interior, right_edge) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Compute one component of the trace-free mass moment tensor along an orbit.

# Arguments
- `orbit`: matrix whose first row holds `x` and whose second row holds `y`.
- `component`: index pair; `(1, 1)` selects `xx`, `(2, 2)` selects `yy`, and any
  other value falls through to the off-diagonal `xy` component.
- `mass`: scalar multiplier applied to the result.

# Returns
A vector with one entry per orbit sample. Diagonal components have one third of the
trace `x² + y²` subtracted.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    i = component[1]
    moment = if i == 1 && component[2] == 1
        xs .^ 2 .- (xs .^ 2 .+ ys .^ 2) ./ 3
    elseif i == 2 && component[2] == 2
        ys .^ 2 .- (xs .^ 2 .+ ys .^ 2) ./ 3
    else
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the moment-tensor
component selected by `component` for the given `orbit` and `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the `(1, 1)`, `(1, 2)` and `(2, 2)` quadrupole components for `orbit`.

Returns the tuple `(h11, h12, h22)`; note the off-diagonal term is the second entry.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    return (
        h_22_quadrupole_components(dt, orbit, (1, 1), mass),
        h_22_quadrupole_components(dt, orbit, (1, 2), mass),
        h_22_quadrupole_components(dt, orbit, (2, 2), mass),
    )
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain of a single body following `orbit`, sampled with
spacing `dt`.

# Returns
A tuple `(c * h₊, -c * hₓ)` with `h₊ = h11 - h22`, `hₓ = 2 * h12` and
`c = sqrt(π / 5)` evaluated in the number type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The quadrupole components already carry a factor 2, so sqrt(π / 5) yields the
    # overall sqrt(4π / 5) prefactor. The square root had been dropped from this line.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components of two bodies.

Each body's `(h11, h12, h22)` triple is computed from its own orbit and mass, and the
triples are added component-wise. Returns `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain from the orbits of two bodies (BH 1 with `mass1`,
BH 2 with `mass2`), sampled with spacing `dt`.

The masses must sum to one (checked to within `1e-12`).

# Returns
A tuple `(c * h₊, -c * hₓ)` with `h₊ = h11 - h22`, `hₓ = 2 * h12` and
`c = sqrt(π / 5)` evaluated in the number type `T` of `dt`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The quadrupole components already carry a factor 2, so sqrt(π / 5) yields the
    # overall sqrt(4π / 5) prefactor. The square root had been dropped from this line.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into a gravitational waveform.

# Arguments
- `dt`: spacing between the saved samples in `soln`.
- `soln`: solution matrix accepted by `soln2orbit` (2 or 4 rows).
- `mass_ratio`: value in `[0, 1]`. `0` selects the one-body (test particle) formula;
  any positive value selects the two-body formula.
- `model_params`: forwarded to `soln2orbit`.

# Returns
The tuple of the two strain polarisations produced by `h_22_strain_one_body` or
`h_22_strain_two_body`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparison operators: the original `≤` / `≥` had been lost from these lines.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Normalised masses: m₁ / m₂ == mass_ratio and m₁ + m₂ == 1.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$u[1] = \chi, \quad u[2] = \phi$

where, p, M, and e are constants

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE in the out-of-place form `f(u, params, t)`.

# Arguments
- `u`: state with `u[1] = χ` and `u[2] = ϕ`.
- `(p, M, e)`: constant model parameters, destructured from the second argument.
- `t`: time (unused; the system is autonomous).

# Returns
The vector `[dχ/dt, dϕ/dt]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)

    # Factor shared by both rates.
    common = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)

    dχ = common * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * root)
    dϕ = common / (M * (p^(3 / 2)) * root)

    return [dχ, dϕ]
end

# Global settings shared by the true-model simulation and the neural-ODE training below.
mass_ratio = 0.0         # 0 selects the one-body (test particle) waveform formula
u0 = Float64[π, 0.0]     # initial state: χ = π, ϕ = 0
datasize = 250           # number of saved time points
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform (Float32 literals)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved timestep
# Spacing between saved samples; used as the finite-difference step for the waveform.
dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step passed to `solve` (used with `adaptive=false`).
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4, saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first of the two returned strain components as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers at the saved samples.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Line and markers share a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the orbital variable χ, i.e. `u[1]`) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

It is typically never recommended to use globals, but in case you do use them, make sure to mark them as const.

We will deviate from the standard Neural Network initialization and use WeightInitializers.jl:

julia
# Network mapping 1 input to 2 outputs through two hidden layers of width 32.
# All weights are drawn from a truncated normal with std 1e-4 and all biases start
# at zero, so the untrained output is close to zero.
const nn = Chain(
    # Parameter-free first layer: `fast_activation` with `cos` fixed as its first
    # argument (presumably applies `cos` elementwise to the input; confirm in Lux docs).
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    # Output layer has no activation.
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# `ps` holds the initial parameters and `st` the layer states.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[7.860439f-6; -5.2525433f-5; 0.00013166785; 7.766766f-5; -0.00014561972; -0.00017803736; -2.7755737f-5; 0.00010047881; 6.7428086f-6; -0.0001594555; -9.549311f-5; 4.0078412f-5; -9.89078f-5; 0.00018193704; 8.555898f-5; -4.970338f-5; -3.8765254f-5; -1.3489131f-5; -0.00013164493; 4.297612f-5; 1.1798019f-5; -0.00011137081; 3.6243822f-5; -9.186421f-5; -6.508293f-5; 4.663855f-5; -1.4155018f-5; 0.00016139026; -0.0001249956; -8.825055f-5; 6.864034f-6; -0.000113623086;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.00016437678 -3.0016356f-5 0.00021092324 -1.9295398f-5 6.400007f-5 3.5518326f-5 -7.671253f-5 7.6223337f-6 -1.7718496f-5 -0.00013926686 -1.9196865f-5 -0.00014257817 4.0983534f-5 7.9018006f-5 -0.00010902154 -8.941748f-5 8.6309075f-5 0.00011615165 0.00013435379 -7.870236f-5 -8.417054f-5 -6.1048006f-5 -0.000110034176 -1.0304273f-5 0.00010379674 0.000118325755 0.00018107906 -5.199068f-5 -0.00013425709 -1.3520461f-5 -7.018773f-5 -8.1540806f-5; -0.0001138165 0.00013198743 6.216527f-5 0.00010475709 -0.00012259047 -3.5811176f-5 -5.9808615f-5 4.6188074f-5 0.0001017403 8.599425f-5 2.4665938f-5 0.00016514427 0.000119936565 0.00017141782 0.00018569987 8.3548126f-5 -0.000119457975 -0.000102200815 -2.5700943f-5 -3.5522182f-6 -0.00014269249 -1.3331167f-5 9.086315f-6 -3.1403084f-5 -0.00010148137 -8.164521f-5 -7.184492f-5 9.113478f-5 -5.43189f-5 -0.0001400215 -3.2325417f-5 -6.1269043f-6; 9.969454f-6 5.1681636f-5 0.0001323837 -6.091742f-5 -2.84636f-5 -5.0008377f-5 -1.2507815f-5 -6.465501f-5 1.4994268f-6 -0.00013679356 2.4531822f-5 0.00012220755 0.00018068936 4.203786f-5 -3.4607397f-6 -0.00023079172 5.7677495f-5 0.00012361121 9.097084f-5 0.0001160965 -0.00015204407 0.00011802255 -4.591044f-5 3.2231765f-5 6.239905f-5 -3.0868632f-5 -0.00014959059 9.025776f-5 -4.874964f-5 
3.5693734f-5 4.917927f-5 0.00010072045; -2.0711632f-6 -3.1827927f-5 0.00013867537 0.00013392804 8.84111f-5 6.3304474f-6 -4.025227f-5 -0.00011997796 -0.0001582737 8.382454f-5 3.6359375f-5 -9.078856f-5 -4.752343f-5 0.00019082347 0.000115616705 7.918959f-5 -6.270233f-5 -0.00022068835 0.00016179885 2.8037526f-5 9.4876465f-5 5.1381292f-5 4.4210457f-5 7.896579f-5 -6.684568f-5 5.9662198f-5 -7.461792f-5 1.0225014f-5 -4.5663954f-5 -2.357594f-5 -7.383484f-5 -3.5562636f-5; -0.000116719355 0.00016689862 -0.000100075136 8.6156906f-5 0.00010011918 2.8148857f-5 5.4008928f-5 7.3757816f-5 6.9377784f-5 5.9282804f-5 1.8332275f-5 -0.00015216526 0.0002312166 4.3590888f-5 -0.00019778108 -3.14341f-5 -3.1108688f-5 0.00014344706 -0.00011801217 2.2652677f-5 -1.20078785f-5 0.0001323587 -2.9821753f-5 -0.0002715281 -5.2604533f-5 -7.438716f-5 -0.0002244406 0.000118518656 1.0915055f-5 7.141488f-5 7.474548f-5 -0.00011599659; 0.00018576678 9.744388f-5 -2.086525f-5 -0.00010667362 2.5712141f-5 -1.9456742f-5 -1.2227168f-5 7.856932f-6 0.00014285228 1.5197984f-5 0.00010530933 -9.205047f-5 5.8769456f-5 -0.00027841583 1.1447948f-5 -5.0622653f-5 3.1692068f-6 -5.100443f-5 -5.075083f-5 0.00010642333 -3.5123762f-6 -9.171245f-6 -5.301154f-5 -9.309107f-6 -0.0001233679 -0.00015029332 -0.00016648695 -3.7564741f-6 0.00010693836 -4.192232f-5 -6.602926f-5 -1.7919037f-5; -0.00010453267 0.00015391129 9.915401f-5 8.348953f-5 -7.9054924f-5 5.6505087f-5 0.00012845981 1.7698028f-5 1.3949116f-5 -9.806561f-5 -6.420396f-5 0.0002217041 4.7849775f-5 -5.236996f-5 4.9019938f-5 -6.746362f-5 0.00014889111 5.7057696f-5 -0.00010691402 0.0001795911 -4.0313258f-5 8.079401f-6 7.2666726f-6 -5.686987f-6 6.0907998f-5 -5.126703f-5 8.3842715f-5 0.00016367433 -5.5189397f-5 -0.0001455218 -0.00010544503 -0.00015418504; 0.00022932114 9.020424f-5 3.0043677f-5 -5.2220155f-5 0.00010644842 -5.0049846f-5 3.6261197f-5 -3.728025f-5 2.4976409f-5 -0.00023991911 -0.00021004882 -2.6648035f-5 -0.00015839229 4.5194247f-5 -0.00010322864 -9.639724f-5 
0.00012596064 1.3371975f-5 -0.00015947327 -6.847464f-5 9.8559f-5 -3.992497f-5 -3.391051f-5 3.467761f-5 -5.4849126f-5 -4.343793f-5 -0.00010100763 -0.00013983452 0.00018997362 -0.00026679982 -0.00013637258 5.1325955f-5; -7.251158f-5 -2.21497f-5 -5.078739f-5 8.177907f-5 -0.00014436334 6.0707163f-5 -4.0421226f-5 0.0001102568 0.00015287085 6.8625064f-5 -7.154448f-5 7.953224f-5 0.0001660816 0.00014572355 0.000100684876 -5.7809317f-5 -0.00022931145 -0.00020820695 -6.144196f-5 -0.00012487781 -5.33979f-5 9.256115f-5 1.6038002f-5 0.0001195173 4.6324552f-5 -0.00027295144 7.2264105f-5 8.911913f-5 -6.401373f-7 0.00014419894 0.00012912118 1.9567065f-5; 6.820378f-5 2.12501f-5 0.0001519751 0.0001277628 -4.5427285f-5 9.45922f-5 2.8407487f-5 3.679055f-5 -9.1179245f-5 -8.759315f-5 -4.233391f-5 -1.0148495f-5 -2.8013366f-5 0.00019726023 4.0231207f-5 -2.0327785f-5 0.000199447 -0.000111626985 -0.00013777778 -0.000117939824 -6.2143245f-6 0.00020680211 -1.3994887f-5 8.546733f-5 0.000104675535 3.492703f-5 -8.7907225f-5 0.00010485748 7.1407274f-5 0.00011894936 0.00013858061 -2.8788549f-5; -1.5867355f-6 4.1940072f-7 0.00021171928 4.52169f-6 6.012732f-5 9.992477f-5 -3.561707f-5 -6.138835f-5 -0.00017997247 6.2459614f-5 -1.34978145f-5 -1.5011756f-5 -0.00012783983 0.00011435254 0.00019095208 2.1907656f-5 -4.116716f-5 -1.17440095f-5 7.0702874f-5 -9.540582f-5 7.0204645f-5 0.00013954662 0.00010261999 -0.0001211858 -7.0892784f-5 -0.0001676153 -1.7646533f-5 -3.4530327f-5 -2.2219998f-5 -2.18182f-5 -4.4124412f-5 0.000111558154; 0.00016322506 0.00014593737 3.3885233f-5 0.0001275012 8.391951f-5 0.00018750902 8.676476f-6 0.00020314452 2.0860696f-5 1.0924117f-5 -0.0001489237 -3.377215f-5 0.0001377311 -6.447554f-5 1.8870135f-6 -1.525362f-5 -0.00018388667 -0.00020160212 7.516864f-5 -8.7843495f-5 -9.126322f-5 -1.8796603f-5 -1.542366f-5 2.639844f-5 -5.7380897f-5 -4.7694493f-6 0.00019579113 0.00010669442 -2.5035256f-6 -5.628761f-5 -0.00012336455 -1.8165401f-5; 4.386997f-5 3.080651f-5 0.00018174354 -2.854481f-5 
-5.467102f-5 3.1228672f-5 5.778987f-5 -3.1693147f-5 -5.9755175f-6 7.561135f-5 4.518769f-5 6.331077f-6 -5.9239874f-6 -0.00021513706 1.7105978f-5 -2.608389f-5 6.7201822f-6 9.437318f-5 9.7142896f-5 3.587049f-5 2.990647f-5 9.455785f-5 0.00013769187 7.590747f-5 0.00012126878 -4.2956963f-5 3.1407562f-5 0.00016296902 -2.3624254f-5 5.0848314f-5 2.0582893f-5 -9.444387f-5; -9.685379f-5 4.364414f-5 -8.146248f-5 0.00012298896 -0.0001297601 5.0669637f-6 -0.00020389126 -8.497968f-5 -3.2198812f-5 5.7340447f-5 5.6329452f-5 0.00012570899 -6.903591f-5 9.564116f-5 0.0001340695 -4.0525847f-6 -0.00014826849 -9.674515f-5 -0.00012002463 -1.9548655f-5 -6.245278f-5 -0.00010349815 7.256741f-6 5.4170167f-5 -9.72962f-5 -0.00011858146 -4.4183424f-5 3.905559f-5 7.1765775f-5 2.2052234f-5 -9.916257f-5 -0.00018513115; 2.0909301f-5 -4.8683356f-5 -8.092678f-5 8.124017f-5 3.441194f-5 -0.00010688532 -0.00017534495 -9.5536416f-5 5.324465f-5 1.4815658f-5 -2.5158256f-6 2.1041446f-5 -0.00013107157 0.00015172818 9.03989f-5 1.0895201f-5 -3.224868f-5 -0.00013421122 5.294083f-6 3.5997127f-5 6.929181f-6 2.8395696f-5 2.9743112f-5 -0.0001417928 6.4311724f-5 -0.00013529892 0.00019494956 1.6890172f-5 -0.00019581386 -0.0001004215 0.00013518709 0.0002651039; 0.00011439457 -0.00017873818 -8.155978f-5 -2.3657272f-5 5.2572363f-5 1.871703f-5 -8.4813604f-5 -0.00020245861 -6.970776f-5 0.00015375929 -0.00022528024 1.1745393f-5 2.5853018f-5 -1.4499828f-6 0.00012424964 2.6088617f-6 -3.924224f-5 -1.40888005f-5 -0.00012934036 -4.0908646f-5 -0.000109726425 7.257933f-5 0.00026414773 8.74527f-6 3.3685497f-5 9.373089f-5 -2.3758246f-5 3.0599972f-6 -9.0555855f-5 0.00014242473 -6.925916f-6 9.393081f-5; 0.00018411597 -0.00010941034 0.00015787876 9.781739f-5 0.00012580647 -3.7076934f-5 5.9199956f-5 5.7521836f-5 6.913701f-5 -1.6429325f-5 -8.043186f-5 -6.6614945f-7 6.404517f-5 0.0001124876 -8.0944104f-5 1.560465f-5 -0.00016145661 -0.00020957802 0.00012761129 0.00012393577 8.609023f-5 0.00027474988 -9.616394f-5 -9.7552016f-5 9.3259354f-5 
8.2833896f-5 0.00016256991 5.094344f-5 -0.00016081546 0.00013049027 2.501641f-5 9.764774f-6; 4.4500546f-5 -0.00011666155 -5.058511f-5 -0.0001115021 -0.00020778585 4.6901692f-5 2.2114928f-5 -2.4246567f-5 -7.744055f-6 1.7532006f-5 -9.8236895f-5 -4.661263f-5 7.069244f-5 -2.470423f-5 8.236617f-5 -1.9313466f-5 7.0327774f-6 6.528813f-5 -6.1844024f-5 5.0168937f-6 0.00016198569 -0.00014750911 -3.796759f-5 7.6624834f-5 0.00013080529 0.000107561624 8.559652f-5 0.00015994655 3.5554945f-5 9.201267f-5 -4.8966453f-5 -0.00013019545; 0.00010356599 9.386258f-5 -9.555804f-5 -0.00011302738 -2.982038f-5 -2.6537766f-5 0.0001326902 1.3503379f-5 0.0002321677 -0.00014570077 -6.160482f-5 9.383279f-5 -0.00012899985 0.00026124864 7.559704f-5 -3.245845f-5 2.2868871f-5 7.393422f-5 6.939985f-5 -8.073464f-5 -6.493277f-5 -7.971229f-5 -3.690797f-5 7.159614f-5 0.0002240174 -4.9012317f-5 -6.309731f-5 6.5474334f-5 -3.9521747f-5 2.5880536f-5 9.168814f-6 7.0612776f-5; 3.6504764f-5 -0.00014354469 -0.000107286105 0.00016915759 6.8394984f-5 -3.302745f-5 -0.00015622316 0.00012367406 4.196102f-5 -0.00024222277 2.7071996f-5 -0.00016572379 -1.2416313f-6 9.093682f-5 0.00013993763 -2.9302382f-5 -4.238206f-5 6.964645f-6 0.00017332078 -0.0001184762 2.8117662f-5 -0.00012941372 -2.5065112f-5 1.2464733f-5 -0.00010473354 0.0001495244 0.00016974339 -1.9494899f-5 -3.4896857f-5 -5.3749554f-5 -3.168685f-5 0.00017349978; -5.132354f-6 -4.424082f-5 -4.964353f-5 -1.5468098f-5 -2.581242f-5 3.4385128f-5 -0.00016474552 3.214969f-5 -2.6949356f-6 -7.3306846f-6 -0.00023628463 -2.713697f-5 4.682057f-5 -5.0788974f-5 -6.5953805f-5 -2.250798f-5 -0.00015827961 -8.265855f-6 0.00013737188 -6.6480216f-6 3.0908624f-5 2.7139635f-5 0.00023415721 -8.188687f-5 -6.4403386f-5 -8.165859f-5 -9.382276f-5 -2.3535324f-5 6.069784f-5 -7.942674f-5 0.00018662795 2.8620603f-5; -0.000110345616 4.7111524f-5 4.532873f-5 -5.5340144f-5 -0.00020332819 -1.6361653f-5 -8.1452934f-5 0.00020983846 -1.7723463f-5 6.42776f-6 -0.00011324887 -4.4252903f-7 -8.6307926f-5 
-6.0750586f-5 3.6340498f-5 -0.0001346239 1.4279817f-5 -0.00015716022 -0.00011565898 0.00010061201 -4.9437135f-6 3.4787143f-5 -4.7063295f-5 -0.00018093284 1.5014159f-5 -4.8186812f-5 0.00010728006 -4.107313f-6 8.443386f-6 0.00018245523 0.00015710696 0.00010806732; -4.593428f-5 3.7794541f-6 0.0001901828 -0.00015562437 0.00010617965 2.7162823f-5 -0.0001961611 -6.838603f-5 0.00010460149 1.8424496f-5 -1.7457225f-5 0.00010073535 -1.3157878f-6 4.764597f-5 -7.2148105f-5 -6.362214f-5 -8.791657f-5 2.551204f-5 -7.157915f-5 7.802277f-5 -6.758359f-5 0.00011234907 -2.9502333f-5 4.234131f-5 1.3768293f-5 0.00015046699 -2.074506f-5 1.5611389f-6 -8.091321f-5 -2.2058419f-5 -9.851729f-5 4.396451f-5; -7.404357f-5 -0.000116137926 -3.919746f-5 6.9903613f-6 0.00015672618 -0.00012052637 -5.075656f-5 -4.31624f-5 -9.2455695f-5 3.0230767f-6 7.4591815f-5 -0.00017187528 8.2586484f-5 1.3629483f-5 4.858484f-5 -1.0701159f-5 -9.76107f-5 8.8346635f-5 5.066664f-6 0.00011652417 -0.00016424122 6.7113804f-5 6.106476f-5 9.281236f-5 3.989169f-5 -0.00013496888 9.74449f-5 -0.000106641215 0.0002279588 0.00017059782 -7.062001f-5 0.00013929975; -4.043872f-5 -4.6237346f-5 -1.505656f-5 -8.941849f-5 -7.745332f-5 -1.8620416f-5 -6.3493855f-5 8.04013f-6 4.0744246f-5 5.7554855f-5 1.9408004f-5 4.646881f-5 0.00020582223 -0.00015895322 0.00014299557 0.00029195202 -6.145626f-5 -5.6737663f-5 -0.00015056907 -8.007406f-5 3.0197156f-5 -0.000121206926 -0.00016406784 -0.0001453998 -7.0228336f-5 -0.00016744592 -2.524225f-5 -3.979123f-5 -5.5559358f-5 3.8262624f-5 1.5647951f-5 0.00022073666; 5.5627712f-5 6.8776084f-5 -7.352222f-5 1.0002345f-6 9.808353f-5 -0.000106885396 -7.6797754f-5 3.6088128f-5 -3.8107824f-5 -0.00020009115 -3.370073f-5 7.5047414f-5 -0.00011180497 3.9220264f-5 3.158692f-5 1.054442f-6 -0.00023945759 -0.00012433676 -2.6348956f-5 -5.522512f-5 -0.00022014593 -0.000115676325 -3.7359912f-6 -7.85893f-5 -7.1170274f-5 -3.0461179f-5 8.644194f-5 -2.413197f-5 -5.5292185f-5 0.00013858915 5.4612905f-5 -1.0997543f-5; 
2.031275f-5 6.189822f-5 -0.00016143479 -8.91556f-5 1.1044021f-5 1.8451106f-5 7.726851f-6 0.00010280489 0.00027500797 -0.00015673127 8.405811f-6 -0.00011637053 7.788022f-5 8.792073f-5 -2.3665384f-6 -0.00023994221 9.9212775f-5 5.095716f-5 -0.00016145442 7.010363f-5 -9.213432f-5 5.1421102f-5 0.000120278026 -0.00019157163 3.3669326f-6 2.5090276f-5 -0.00010461801 8.769612f-6 8.415159f-5 0.00012264286 -6.644246f-5 -2.309465f-6; -8.704746f-5 8.5311745f-5 4.154299f-5 0.0001244581 0.0002487938 -5.0771534f-5 2.9506034f-5 7.959508f-5 -6.330686f-5 8.9787296f-5 0.000145846 -1.31013385f-5 -4.477698f-5 -0.00015706575 -2.0324507f-5 2.814357f-5 7.020316f-5 1.8658266f-5 3.0784297f-5 0.00023089771 -0.00014576453 -6.0703285f-5 0.00017205071 -4.7142464f-7 -3.542533f-5 -8.493326f-5 7.959889f-5 0.00012569474 -0.00010559724 -0.00021669785 0.00016136313 0.0001259058; -1.6658827f-5 -3.0758212f-5 3.1607575f-5 5.1340565f-5 -0.00016384062 -5.1519546f-5 -2.1732601f-5 -6.084497f-5 3.4906643f-5 5.818177f-5 -4.396652f-5 0.00018607815 -8.900506f-5 0.000112422706 -7.666776f-5 3.1738666f-5 1.0709513f-5 -9.913219f-5 6.0799746f-5 -1.1407148f-5 -0.00013092186 -1.05260415f-5 1.4557312f-5 -5.050489f-5 -0.00010791435 -7.1867944f-5 -0.0002690436 -3.3812503f-5 -2.2607754f-5 -1.2320319f-5 -0.0001482694 2.9437142f-5; 6.869186f-5 -1.277427f-8 -6.626843f-5 -2.1587384f-5 7.173915f-5 -3.0154533f-5 3.0754265f-5 -5.243435f-5 -0.00017200179 4.5476485f-5 -9.203399f-5 -0.00013684381 7.4456216f-6 2.2765864f-6 5.940914f-6 -0.000103125516 -0.00010338021 7.6328126f-7 -4.4208286f-6 3.710244f-5 -1.5903246f-5 -0.00020041747 5.2565396f-5 1.0064219f-5 0.00013255388 0.00018796537 -1.46149805f-5 4.7379745f-5 -5.66258f-5 7.8980884f-5 0.00012663766 4.566003f-5; 6.0409773f-5 0.00010700812 -0.0001300802 0.00013264797 -9.973728f-5 0.00018789402 6.6095f-5 0.00029767348 -0.00012588724 -3.6550168f-6 2.0192834f-5 9.936858f-6 4.4869426f-5 3.8916845f-5 -7.401763f-5 -0.00018950703 -0.00011452327 -3.1347794f-5 7.583299f-5 -1.9396692f-5 
-0.00015249474 -3.6658774f-5 1.4171241f-5 0.00012923034 0.00012168421 -9.645912f-5 -8.4772786f-5 -0.00020042581 3.073528f-5 6.52481f-5 2.0169764f-5 -8.7430744f-5; -0.00015594509 -6.431986f-5 8.600898f-5 -6.847904f-5 1.2529786f-5 -8.806717f-6 3.4579807f-5 -2.468049f-5 -8.22499f-5 -9.3672985f-5 -9.418263f-6 -5.1334155f-5 0.0001442151 -4.4076656f-5 8.811325f-5 7.1457915f-5 -5.6878343f-5 0.00011270518 5.951645f-5 0.0001358126 -7.364864f-5 -0.00011210586 8.181958f-5 6.882281f-5 -5.7000256f-5 7.8650046f-5 1.134115f-5 4.324312f-5 -0.00015913602 -1.3885723f-5 3.2554562f-5 -0.00013899442], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[4.720289f-5 -6.1033745f-5 -0.000117455354 6.339205f-5 9.4987015f-5 0.00010637434 -0.00014580597 -7.928112f-5 0.00012856306 5.526573f-5 5.67344f-5 -3.319828f-5 1.4563848f-5 -1.8394529f-5 -6.5796544f-6 -9.174684f-5 3.709855f-5 5.6026445f-5 -5.2116233f-5 -3.3011038f-5 -0.0002144286 -6.186036f-5 0.00013419699 -3.4802364f-5 6.544735f-5 7.683924f-5 4.327698f-5 0.00013636211 -3.126764f-5 -4.0090868f-6 -0.0001071095 -8.382162f-5; 2.4680015f-5 -9.9803234f-5 3.5336463f-5 0.00012122493 2.7823971f-5 0.00014654118 -3.1051786f-5 7.1056456f-6 9.648971f-5 8.8834306f-5 -8.428908f-5 2.0494661f-5 -0.00013468543 -0.00011850743 0.00015842692 -0.0001187949 0.00015368695 -3.531974f-8 2.317611f-5 -5.6460435f-6 7.135657f-5 -3.1484447f-5 0.000107748965 -8.069987f-5 -4.6652836f-5 -0.00012347964 -0.00013556932 -2.4259138f-5 -9.283368f-5 -0.00015578445 7.999232f-5 -6.526465f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32, however, in this case we need Float64

julia
# Convert the Float32 parameters to Float64 and flatten them into a ComponentArray,
# which is what the optimizer works on.
const params = ComponentArray(f64(ps))

# Wrap the network together with its state so it can be called as `nn_model(x, ps)`.
# NOTE(review): the `nothing` argument is presumably the stored-parameters slot, left
# empty because parameters are passed explicitly on each call; confirm in Lux docs.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics. It uses

$u[1] = \chi, \quad u[2] = \phi$

where, p, M, and e are constants

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE: the Newtonian orbit rates multiplied by a learned
correction.

# Arguments
- `u`: state with `u[1] = χ` and `u[2] = ϕ`.
- `nn_params`: parameters of the neural network `nn_model`.
- `t`: time (unused).

# Returns
The vector `[dχ/dt, dϕ/dt]`, where each rate is
`(1 + e cos(χ))^2 / (M p^(3/2))` times `1 + nn_output[i]`. The constants `p`, `M`
and `e` are read from the global `ode_model_params`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example the network state `st` is an empty NamedTuple, so it can safely be
    # ignored; in general, `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same fixed-step
# RK4 settings and save points as the reference simulation.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# First strain component predicted by the untrained model.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Untrained neural-ODE prediction.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Each line/marker pair shares one legend entry; legend placed at position `:lb`.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean-squared-error objective used to compare predicted and reference waveforms.
const mseloss = MSELoss()

"""
    loss(θ)

Training objective for the neural ODE.

Solves `prob_nn` with network parameters `θ` (fixed-step RK4, saved at `tsteps`),
converts the trajectory into a waveform, and returns the mean squared error between
its first strain component and the reference `waveform`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0=u0, p=θ, saveat=tsteps, dt=dt, adaptive=false)
    strains = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    return mseloss(first(strains), waveform)
end
loss (generic function with 1 method)

Warmup the loss function

julia
# Evaluate the loss once at the initial (untrained) parameters as a warm-up call.
loss(params)
0.0007120370537351161

Now let us define a callback function to store the loss over time

julia
# Loss value recorded at every callback invocation.
const losses = Float64[]

"""
    callback(θ, l)

Progress callback for the optimizer.

Appends the current loss `l` to `losses` and prints the iteration counter read from
`θ.iter` together with the loss. Always returns `false`.
NOTE(review): `false` presumably tells the optimizer to keep iterating; confirm
against the Optimization.jl callback documentation.
"""
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It works well here because the Newtonian model already provides a very good initial guess.

julia
# Gradients of the loss are obtained through Zygote.
adtype = Optimization.AutoZygote()

# The objective takes (parameters, extra argument); the extra argument is not used
# by `loss`.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)

# Start the optimization from the current network parameters `params`.
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, capped at 1000 iterations; `callback` is
# invoked by the optimizer to log progress.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [7.860438927306752e-6; -5.2525432693069914e-5; 0.0001316678535656434; 7.766766066189403e-5; -0.00014561972056968258; -0.00017803735681794764; -2.7755737391948064e-5; 0.00010047880641644231; 6.742808636772987e-6; -0.0001594555069458531; -9.549310925642917e-5; 4.0078411984731166e-5; -9.890780347622999e-5; 0.00018193703726856453; 8.555898239135396e-5; -4.970337977284799e-5; -3.876525443043724e-5; -1.3489130651576472e-5; -0.0001316449342992339; 4.2976120312192544e-5; 1.1798018931580595e-5; -0.00011137081310125814; 3.624382225096552e-5; -9.186421084440545e-5; -6.50829315417017e-5; 4.663854997486031e-5; -1.4155018106957071e-5; 0.00016139025683481036; -0.00012499559670698556; -8.825054828772068e-5; 6.8640338213228945e-6; -0.00011362308578100378;;], bias = [3.1690976299135775e-18, -4.722969324672089e-18, 2.7588708788696755e-16, 8.7796646292919e-17, -1.4630496965090196e-16, -1.6769321880840118e-16, -5.820014156550568e-17, 6.977471489405705e-17, -2.890529698946684e-19, -1.2224450295050418e-16, -7.643036073865617e-17, -9.501276662725863e-18, -1.0611435143202708e-16, 9.369849162608388e-17, 2.5226215178435307e-17, 1.3438622433120839e-17, -4.123813299284313e-17, -2.6249650815877978e-18, -1.913405972317005e-16, 4.5660264392232926e-17, 5.99414400413261e-18, -3.074086997943475e-16, 1.7365310931321055e-17, -1.1415485472436397e-16, -1.5490152008516075e-16, 2.986866709460698e-17, -2.160552188993909e-17, 3.833659073928962e-16, 1.337248393465031e-16, -4.397377880384665e-17, 7.926768205002992e-18, 3.816684556720878e-18]), layer_3 = (weight = [-0.00016437806084832564 -3.001764094188149e-5 0.00021092195924757385 -1.9296682925416568e-5 6.399878725097083e-5 3.551704025631777e-5 -7.671381650107335e-5 7.621048322107302e-6 -1.7719781002601727e-5 -0.00013926814904515887 -1.9198150088414602e-5 -0.00014257945190729893 4.098224848894425e-5 7.90167202885735e-5 -0.00010902282270217395 -8.94187662688671e-5 8.630779006670836e-5 
0.0001161503630519248 0.0001343525025240552 -7.870364538939978e-5 -8.417182585072153e-5 -6.104929093817117e-5 -0.00011003546137908805 -1.0305558365872467e-5 0.00010379545232899278 0.00011832447011874692 0.00018107777099975937 -5.199196425928032e-5 -0.00013425837574140834 -1.352174627209143e-5 -7.01890141990633e-5 -8.154209104213711e-5; -0.00011381601295711315 0.00013198791778136162 6.216576214369268e-5 0.00010475758092666106 -0.0001225899811046961 -3.5810686016207306e-5 -5.9808125728022707e-5 4.61885632162634e-5 0.00010174078700088122 8.599474160887054e-5 2.4666427931650673e-5 0.00016514476221941542 0.00011993705477520825 0.0001714183111530878 0.0001857003611135487 8.354861559068904e-5 -0.0001194574849009314 -0.00010220032497586527 -2.5700453691981974e-5 -3.5517286264167534e-6 -0.00014269199680115872 -1.3330677278147003e-5 9.086804608659799e-6 -3.1402594389375056e-5 -0.00010148087828697545 -8.16447188057738e-5 -7.184443318096766e-5 9.113527114690955e-5 -5.4318412070144876e-5 -0.00014002100731290844 -3.2324927518356876e-5 -6.126414676557068e-6; 9.971092314449045e-6 5.168327419720576e-5 0.00013238534387833572 -6.0915782395239954e-5 -2.8461962809749666e-5 -5.0006739055474116e-5 -1.2506177025566816e-5 -6.465337270005083e-5 1.501064790649838e-6 -0.00013679192403164993 2.4533460365418516e-5 0.00012220918955914726 0.00018069099400359748 4.203949805864354e-5 -3.459101716216215e-6 -0.00023079007865693416 5.7679132570559695e-5 0.00012361285275016106 9.097247761142576e-5 0.00011609813458686727 -0.00015204243305437138 0.00011802418970670021 -4.590880331182789e-5 3.223340258169601e-5 6.240068700332101e-5 -3.086699421931972e-5 -0.0001495889510430383 9.025939738544304e-5 -4.874800202619058e-5 3.569537233612217e-5 4.918090866464321e-5 0.00010072208619145118; -2.0699493046023594e-6 -3.182671328369754e-5 0.00013867658587576022 0.00013392925636315366 8.841231344438132e-5 6.331661252869799e-6 -4.0251057115999584e-5 -0.00011997674421702097 -0.00015827248837860415 8.382575349121165e-5 
3.636058924774163e-5 -9.078734849874109e-5 -4.752221623614002e-5 0.0001908246837689771 0.00011561791931915623 7.919080114396076e-5 -6.270111892201181e-5 -0.00022068713708842295 0.00016180005938683768 2.8038740011363268e-5 9.487767844864978e-5 5.138250587213179e-5 4.4211671293562466e-5 7.896700723966134e-5 -6.684446391711295e-5 5.966341147603265e-5 -7.46167073242006e-5 1.0226228069469978e-5 -4.566273976645236e-5 -2.3574725973340262e-5 -7.383362511001644e-5 -3.5561422309295024e-5; -0.00011671944717607035 0.00016689853022610273 -0.00010007522847233795 8.615681385192531e-5 0.0001001190854752392 2.8148765197624996e-5 5.4008835506760133e-5 7.37577235143584e-5 6.937769161806867e-5 5.928271151999513e-5 1.8332183246752113e-5 -0.00015216535153529177 0.00023121651124013114 4.359079561465944e-5 -0.000197781167442077 -3.143419358609743e-5 -3.1108780019781914e-5 0.0001434469692342864 -0.0001180122611528308 2.2652585032414297e-5 -1.200797065182629e-5 0.00013235861355708443 -2.9821845016773224e-5 -0.0002715281783436473 -5.260462545054527e-5 -7.438724911269299e-5 -0.00022444069814598694 0.00011851856346977445 1.0914962915489514e-5 7.141478512705753e-5 7.474539110473782e-5 -0.00011599668265050746; 0.00018576495818151595 9.744205777124159e-5 -2.086707213137709e-5 -0.00010667544524594361 2.5710317962796483e-5 -1.9458564971573887e-5 -1.2228990677704538e-5 7.855108882636743e-6 0.00014285045795431792 1.5196160851624561e-5 0.0001053075061953723 -9.20522911165789e-5 5.876763275764955e-5 -0.00027841765081218847 1.1446125021306062e-5 -5.062447639475007e-5 3.1673837874223278e-6 -5.100625316672525e-5 -5.075265330205151e-5 0.00010642150623780278 -3.5141992576435353e-6 -9.173068470857155e-6 -5.3013362454385145e-5 -9.310929677754246e-6 -0.00012336972981401976 -0.00015029513856106813 -0.00016648877176167004 -3.758297176394942e-6 0.00010693653489753093 -4.192414457620167e-5 -6.603108017416245e-5 -1.792086045142471e-5; -0.00010453094441133783 0.00015391301272513888 9.915573460358847e-5 
8.349125997240086e-5 -7.905319657100862e-5 5.650681410525768e-5 0.00012846153836774366 1.7699755660324737e-5 1.3950843055979353e-5 -9.806388410896795e-5 -6.420223220873438e-5 0.00022170583102463424 4.785150227446258e-5 -5.2368232757265255e-5 4.902166543372262e-5 -6.746189032371526e-5 0.0001488928348019967 5.705942308607271e-5 -0.0001069122925788937 0.0001795928211843841 -4.0311530795109046e-5 8.081127959776209e-6 7.2683998562993805e-6 -5.68525987789858e-6 6.090972525028043e-5 -5.126530308432299e-5 8.384444223095423e-5 0.0001636760527229046 -5.518766997073027e-5 -0.00014552007345871738 -0.00010544330584041074 -0.0001541833088073696; 0.0002293175000499449 9.020059383823668e-5 3.004003315933052e-5 -5.2223798995832364e-5 0.00010644477311640511 -5.005348997885512e-5 3.625755358138017e-5 -3.728389332768434e-5 2.4972765238488035e-5 -0.00023992275804860482 -0.00021005246366489718 -2.6651678347261686e-5 -0.00015839592946808743 4.519060326627472e-5 -0.00010323228183783848 -9.640088523388743e-5 0.00012595699241268427 1.3368331344680036e-5 -0.00015947691849081775 -6.847828466794432e-5 9.855535768787195e-5 -3.99286129709671e-5 -3.3914153611986386e-5 3.4673966872513915e-5 -5.4852769828864104e-5 -4.344157263636527e-5 -0.00010101127394824462 -0.00013983816301563646 0.00018996997983545798 -0.00026680346331209136 -0.00013637621872709328 5.1322311044218384e-5; -7.251044679850486e-5 -2.214856375775544e-5 -5.078625271398424e-5 8.178020708310985e-5 -0.00014436220170882525 6.0708298960706135e-5 -4.042009039966561e-5 0.00011025793411118394 0.0001528719882206232 6.862619976327607e-5 -7.154334106434613e-5 7.953337681993639e-5 0.00016608274073387928 0.00014572468408896152 0.00010068601169205792 -5.7808181181999396e-5 -0.00022931031652678 -0.00020820581939023136 -6.14408267121654e-5 -0.00012487667790100188 -5.339676444254032e-5 9.256228860049838e-5 1.6039137890229888e-5 0.00011951843826846697 4.632568799007385e-5 -0.00027295030248768854 7.226524086018874e-5 8.912026588386682e-5 
-6.390016689651213e-7 0.00014420007993049493 0.00012912231669343866 1.956820106880121e-5; 6.820772991093462e-5 2.1254050935816552e-5 0.00015197904589627165 0.0001277667544320766 -4.542333395761709e-5 9.459615115814153e-5 2.8411437725732396e-5 3.679449959350978e-5 -9.117529394527309e-5 -8.758919999196077e-5 -4.2329960577931454e-5 -1.0144543792852978e-5 -2.800941534472828e-5 0.00019726418408138737 4.0235158068087194e-5 -2.032383404978274e-5 0.0001994509440385404 -0.00011162303404373478 -0.00013777382979128283 -0.00011793587314853088 -6.210373484211795e-6 0.00020680606407155213 -1.3990936039339685e-5 8.547127812677502e-5 0.00010467948608608965 3.4930981329375304e-5 -8.790327397830215e-5 0.00010486142868216187 7.141122490907534e-5 0.00011895331421443785 0.00013858456090136087 -2.8784597687936806e-5; -1.5860273009870596e-6 4.201089126071218e-7 0.00021171998620556242 4.522397981048049e-6 6.012802865879095e-5 9.99254795402399e-5 -3.561636306986259e-5 -6.138764486823493e-5 -0.00017997176231293035 6.246032232010124e-5 -1.3497106314760225e-5 -1.5011047557350911e-5 -0.000127839118553682 0.00011435324510383563 0.00019095278791209487 2.190836372901134e-5 -4.1166452624025464e-5 -1.1743301305983084e-5 7.070358211930417e-5 -9.540511140076516e-5 7.020535364571925e-5 0.000139547332460635 0.00010262069796735362 -0.00012118509524444342 -7.08920755045315e-5 -0.00016761459604313673 -1.7645824604556307e-5 -3.452961876465634e-5 -2.2219289653011745e-5 -2.1817491264671302e-5 -4.4123703750670265e-5 0.00011155886265040982; 0.00016322664453061945 0.00014593895468740305 3.38868136389157e-5 0.0001275027840386179 8.392109305010668e-5 0.00018751060417731248 8.67805613259002e-6 0.0002031460986693403 2.086227652787417e-5 1.0925697665026366e-5 -0.000148922123458514 -3.377056998554046e-5 0.00013773267858075127 -6.447396151610769e-5 1.8885938282181468e-6 -1.5252040044038412e-5 -0.00018388509340704775 -0.00020160053679839997 7.517021879199652e-5 -8.784191447896325e-5 -9.126163638550143e-5 
-1.8795023039029765e-5 -1.5422080082975986e-5 2.640002060095067e-5 -5.737931712767606e-5 -4.767868973181286e-6 0.0001957927103142318 0.00010669599725975899 -2.501945232991478e-6 -5.6286028098591877e-5 -0.00012336296492842307 -1.8163820983073793e-5; 4.387362079610815e-5 3.081015913080804e-5 0.0001817471895428147 -2.8541160447322837e-5 -5.4667372553077835e-5 3.123232110548518e-5 5.7793520880254545e-5 -3.168949822980465e-5 -5.97186843337489e-6 7.561500043986861e-5 4.519134037577175e-5 6.334726293019989e-6 -5.920338282573348e-6 -0.00021513340814825556 1.7109627349733646e-5 -2.608024058831441e-5 6.723831318553467e-6 9.437682839957751e-5 9.714654535667561e-5 3.58741383187894e-5 2.9910119535957843e-5 9.456149947976864e-5 0.00013769551534819968 7.591111736287072e-5 0.0001212724275482746 -4.2953313624863676e-5 3.141121143740353e-5 0.0001629726723130462 -2.362060493584466e-5 5.085196264003189e-5 2.0586541921107583e-5 -9.444022342182586e-5; -9.685758408346632e-5 4.364034981512374e-5 -8.146627072896034e-5 0.0001229851696388734 -0.0001297638881970072 5.063172977297584e-6 -0.00020389505510871625 -8.498346863966554e-5 -3.220260254062204e-5 5.7336656083836206e-5 5.632566177332536e-5 0.00012570519908143912 -6.903969909949384e-5 9.563736972448336e-5 0.00013406571093759972 -4.0563754174032545e-6 -0.00014827228000109327 -9.67489394843712e-5 -0.00012002842407859088 -1.9552446190628368e-5 -6.245657272126004e-5 -0.00010350193764502016 7.2529501830457144e-6 5.4166375832163306e-5 -9.729999141023401e-5 -0.00011858525243775148 -4.4187214314718115e-5 3.9051799887576886e-5 7.176198420161946e-5 2.2048443553965248e-5 -9.916636184976858e-5 -0.00018513494427501925; 2.090883834711751e-5 -4.86838190354821e-5 -8.092724339262777e-5 8.123970736591713e-5 3.441147681400975e-5 -0.00010688578651211203 -0.00017534540994716325 -9.553687879274902e-5 5.324418715276299e-5 1.4815194741412272e-5 -2.5162886359376654e-6 2.1040982470317132e-5 -0.00013107203732402376 0.0001517277149886922 9.039843542809253e-5 
1.0894737793882831e-5 -3.224914427306006e-5 -0.0001342116857941215 5.293620129984716e-6 3.59966642335573e-5 6.928717890281797e-6 2.8395233026409616e-5 2.974264942708453e-5 -0.00014179326273192986 6.431126065394449e-5 -0.0001352993832497807 0.0001949491005561362 1.6889708500410728e-5 -0.00019581432308227978 -0.00010042196401086974 0.00013518662570817869 0.00026510344731463784; 0.00011439415779256173 -0.00017873858938167853 -8.156019371875417e-5 -2.3657684648637738e-5 5.2571950293501925e-5 1.871661756246687e-5 -8.481401651573446e-5 -0.00020245902611116832 -6.970817535824231e-5 0.00015375887837087057 -0.0002252806492446639 1.1744980513471436e-5 2.5852605761033913e-5 -1.4503951728611104e-6 0.00012424922410224576 2.608449351614134e-6 -3.924265307199279e-5 -1.4089212852409973e-5 -0.00012934077436278432 -4.0909058368411555e-5 -0.00010972683717903768 7.257891651767383e-5 0.0002641473207791608 8.744857625575049e-6 3.3685085002220845e-5 9.373047454806915e-5 -2.3758658569418754e-5 3.0595848716693464e-6 -9.055626774846455e-5 0.00014242431883991713 -6.926328167024485e-6 9.393039603543552e-5; 0.0001841201183237734 -0.0001094061926737797 0.00015788291330359793 9.782154211686514e-5 0.0001258106231075092 -3.7072783921362604e-5 5.9204106041289164e-5 5.7525986453098875e-5 6.91411578413698e-5 -1.6425175086220383e-5 -8.042771241483175e-5 -6.61999418316991e-7 6.40493190116716e-5 0.0001124917498640867 -8.093995438278304e-5 1.5608799917097923e-5 -0.0001614524635401907 -0.00020957387186674022 0.00012761543894562787 0.00012393991619726798 8.609437918901831e-5 0.00027475403021932237 -9.615978796820858e-5 -9.75478660580217e-5 9.326350391754714e-5 8.283804618666058e-5 0.00016257405787153546 5.094758944728082e-5 -0.00016081130615524798 0.0001304944189582514 2.5020560319246074e-5 9.768924432558082e-6; 4.450085401528978e-5 -0.00011666123970301385 -5.0584801614812454e-5 -0.00011150179267517818 -0.00020778554386581975 4.6902000063500914e-5 2.211523536586659e-5 -2.424625887599876e-5 
-7.743747632254398e-6 1.7532313391503186e-5 -9.82365868951343e-5 -4.6612321570343514e-5 7.069274800939851e-5 -2.470392206690141e-5 8.2366476304535e-5 -1.931315779459223e-5 7.0330851811134024e-6 6.528843590401575e-5 -6.184371665242768e-5 5.0172015433174444e-6 0.00016198599990800488 -0.00014750880433369088 -3.796728339168724e-5 7.662514184246219e-5 0.0001308055965962676 0.00010756193154136161 8.559682686225042e-5 0.00015994685457896646 3.555525330354768e-5 9.201297728920743e-5 -4.8966144790235e-5 -0.00013019513932179325; 0.00010356828599895037 9.386487974977538e-5 -9.555574185955084e-5 -0.00011302508326055312 -2.9818081452814296e-5 -2.6535467691564783e-5 0.00013269249641296085 1.3505676838274635e-5 0.00023217000009451678 -0.0001456984766949469 -6.160252203274579e-5 9.383508470334614e-5 -0.00012899754917755833 0.00026125094038857146 7.559933758299872e-5 -3.2456152689815777e-5 2.2871169425957352e-5 7.393651933368904e-5 6.940214982447468e-5 -8.073233882471435e-5 -6.493047228589041e-5 -7.970999402035335e-5 -3.690567361929861e-5 7.15984412860683e-5 0.00022401970520574907 -4.9010018551894524e-5 -6.30950101143061e-5 6.547663244842151e-5 -3.951944861486171e-5 2.588283380161544e-5 9.171112401123766e-6 7.061507378663521e-5; 3.650465740143488e-5 -0.00014354479637326503 -0.00010728621120428441 0.0001691574828154053 6.839487767216057e-5 -3.302755496328422e-5 -0.00015622326893130153 0.00012367395836996112 4.196091256822626e-5 -0.00024222287603983404 2.7071889490668772e-5 -0.0001657238942417672 -1.2417374083659221e-6 9.093671113156062e-5 0.00013993752807514928 -2.930248840051995e-5 -4.23821663563495e-5 6.964538698281893e-6 0.00017332066934691966 -0.00011847630659680641 2.811755555818023e-5 -0.0001294138243967088 -2.5065218068578445e-5 1.2464626962225503e-5 -0.0001047336452021618 0.00014952430072376674 0.00016974328471483824 -1.9495005034643464e-5 -3.489696311511811e-5 -5.3749660509440975e-5 -3.16869561348932e-5 0.00017349967245614364; -5.134221679304932e-6 -4.4242687577910665e-5 
-4.964539897261032e-5 -1.5469965943313e-5 -2.5814287651847827e-5 3.438326035761383e-5 -0.00016474738862232237 3.214782060607047e-5 -2.6968031633945474e-6 -7.332552227997665e-6 -0.00023628649563394832 -2.7138837536744497e-5 4.681870137940785e-5 -5.0790842065865095e-5 -6.595567304973015e-5 -2.250984783674882e-5 -0.00015828147792124072 -8.267722874711865e-6 0.00013737001660414194 -6.649889142840209e-6 3.090675687783429e-5 2.7137767177949126e-5 0.0002341553432088635 -8.188873965625111e-5 -6.440525379364097e-5 -8.166045648610994e-5 -9.382462798018906e-5 -2.3537191223984183e-5 6.069597162789717e-5 -7.942860741965902e-5 0.0001866260814167841 2.8618735599623768e-5; -0.0001103470353325598 4.711010438881455e-5 4.53273092364272e-5 -5.5341563382858625e-5 -0.00020332961179405889 -1.6363072462575026e-5 -8.145435316578936e-5 0.00020983703831120378 -1.772488252641652e-5 6.426340916054237e-6 -0.00011325028793976707 -4.439482393152609e-7 -8.63093450232133e-5 -6.075200548251977e-5 3.6339078689020515e-5 -0.0001346253197323289 1.4278397940181219e-5 -0.00015716163801585804 -0.00011566039796775824 0.0001006105881617644 -4.945132731116753e-6 3.4785723584034426e-5 -4.706471412533389e-5 -0.0001809342579133539 1.5012739423803873e-5 -4.818823112032762e-5 0.00010727863951728201 -4.108732299523927e-6 8.441966712097673e-6 0.00018245381175915584 0.00015710553958455547 0.00010806589813114257; -4.593441173670279e-5 3.7793233154343817e-6 0.00019018267647010885 -0.00015562449681925028 0.00010617951937133138 2.7162692440875975e-5 -0.00019616122865004856 -6.838615792690198e-5 0.0001046013568888565 1.842436557166921e-5 -1.7457355915716704e-5 0.00010073521860289928 -1.3159186160506811e-6 4.764584098706697e-5 -7.214823546706135e-5 -6.362226833476142e-5 -8.791669808447752e-5 2.5511908625453473e-5 -7.15792774095046e-5 7.802264024616761e-5 -6.758372166546458e-5 0.00011234894207559183 -2.9502463639396427e-5 4.234118039716605e-5 1.3768161971206715e-5 0.00015046686183308797 -2.0745190176562644e-5 
1.5610080985844744e-6 -8.091334149826649e-5 -2.2058549638636323e-5 -9.851741908237734e-5 4.3964377419293234e-5; -7.40427903563168e-5 -0.00011613714464124472 -3.919667956302493e-5 6.991142572812671e-6 0.00015672696013120283 -0.00012052558754454351 -5.07557787839243e-5 -4.316161807742336e-5 -9.245491396518537e-5 3.0238580168242475e-6 7.45925963658433e-5 -0.00017187449425341714 8.258726579719799e-5 1.3630264266102851e-5 4.8585620834498505e-5 -1.070037792465123e-5 -9.760991538291944e-5 8.834741588566336e-5 5.067445317985583e-6 0.00011652495145449258 -0.00016424044310929368 6.711458540871316e-5 6.106554342318183e-5 9.281314401073058e-5 3.989247038611925e-5 -0.000134968101474216 9.744568074006471e-5 -0.00010664043379598136 0.00022795957470429556 0.00017059860242647743 -7.061922674415746e-5 0.00013930053232820265; -4.044082957334265e-5 -4.6239454925667574e-5 -1.5058668716651286e-5 -8.942060080952913e-5 -7.745542675618877e-5 -1.8622524615033393e-5 -6.349596365252841e-5 8.038021084948716e-6 4.0742137250346976e-5 5.7552746190796185e-5 1.9405895406825072e-5 4.646670063012006e-5 0.00020582012225278006 -0.00015895532701758647 0.00014299346516709894 0.0002919499107949414 -6.14583681024639e-5 -5.6739771346472196e-5 -0.00015057118285111284 -8.007616848116997e-5 3.019504733602143e-5 -0.00012120903407120613 -0.00016406994687058322 -0.00014540190600454618 -7.023044450106193e-5 -0.00016744802847169573 -2.5244359087808026e-5 -3.9793338467259885e-5 -5.5561466390632696e-5 3.826051547294631e-5 1.56458424018658e-5 0.0002207345547933181; 5.562366295291923e-5 6.877203504357721e-5 -7.352626908300165e-5 9.961856414851366e-7 9.807948317416799e-5 -0.00010688944512705955 -7.680180333752213e-5 3.608407878004278e-5 -3.811187325825106e-5 -0.0002000951970364765 -3.370477843760378e-5 7.504336481018075e-5 -0.0001118090183842998 3.9216214823932526e-5 3.158286944766428e-5 1.0503931173218698e-6 -0.00023946163474081018 -0.0001243408076368469 -2.6353005345069815e-5 -5.522916745009717e-5 
-0.00022014997721614426 -0.00011568037353186304 -3.740040077691801e-6 -7.859334782893724e-5 -7.117432306019161e-5 -3.0465227431495027e-5 8.643789278378611e-5 -2.413601924102731e-5 -5.5296233589387884e-5 0.00013858510301019026 5.4608856040638585e-5 -1.1001591523157642e-5; 2.0312441915057575e-5 6.18979157244491e-5 -0.0001614350930375328 -8.915590472121852e-5 1.1043714106683836e-5 1.845079902899112e-5 7.726544094230562e-6 0.00010280457906822585 0.000275007665617641 -0.00015673157934167355 8.405503716977726e-6 -0.00011637083837364973 7.787991314592678e-5 8.792042178732485e-5 -2.3668457523570387e-6 -0.00023994251562350593 9.92124678979331e-5 5.095685089520924e-5 -0.00016145472357117588 7.010332599214107e-5 -9.21346308367625e-5 5.142079505652123e-5 0.00012027771851911034 -0.0001915719348524987 3.3666252185238254e-6 2.508996847075708e-5 -0.00010461831873350964 8.769304326172683e-6 8.415127929232049e-5 0.00012264255026287287 -6.644276611575116e-5 -2.309772231336127e-6; -8.704423278478563e-5 8.531497398715528e-5 4.1546218128375996e-5 0.00012446133580149802 0.0002487970322565734 -5.0768305370865204e-5 2.95092631152206e-5 7.9598305021387e-5 -6.330363437674596e-5 8.979052465498004e-5 0.0001458492314871154 -1.3098109846136582e-5 -4.4773753031880795e-5 -0.00015706251913593465 -2.032127856884567e-5 2.8146798215206692e-5 7.020638979364379e-5 1.8661494252489454e-5 3.078752517673589e-5 0.00023090094138959578 -0.00014576129804273736 -6.070005663580276e-5 0.00017205394208057065 -4.681960060427926e-7 -3.542210295355973e-5 -8.49300286770041e-5 7.960211762318715e-5 0.00012569797210949896 -0.00010559401023577317 -0.00021669462337533475 0.00016136636070067408 0.0001259090330879849; -1.6662365804252582e-5 -3.076175102981278e-5 3.1604035986664635e-5 5.1337026026862645e-5 -0.00016384416353321847 -5.152308490141233e-5 -2.1736140250170154e-5 -6.0848508790532144e-5 3.4903104326046616e-5 5.817823062170568e-5 -4.397005882130676e-5 0.000186074609748147 -8.90085966130033e-5 0.00011241916677318738 
-7.667130098345636e-5 3.1735126914982294e-5 1.070597413851729e-5 -9.913572552980692e-5 6.0796207482770946e-5 -1.1410686918707744e-5 -0.00013092540298024035 -1.0529580275049113e-5 1.45537733733078e-5 -5.050842714628955e-5 -0.00010791788640565542 -7.187148271262255e-5 -0.00026904714736522935 -3.381604160313152e-5 -2.2611292431987064e-5 -1.2323857798058405e-5 -0.00014827294485078978 2.9433603278433405e-5; 6.869139095382899e-5 -1.324076598870732e-8 -6.626889891609154e-5 -2.1587850749853613e-5 7.173868569393166e-5 -3.0154999802246147e-5 3.075379852984246e-5 -5.2434815114025904e-5 -0.00017200225640240922 4.5476018460109414e-5 -9.20344545358206e-5 -0.00013684427623789598 7.445155080716845e-6 2.2761199323646083e-6 5.940447483083813e-6 -0.00010312598253054374 -0.00010338067742682837 7.628147676242766e-7 -4.421295141723941e-6 3.710197332278986e-5 -1.5903712590586907e-5 -0.00020041794133990686 5.2564929393922574e-5 1.0063752671925168e-5 0.00013255340953204978 0.00018796490538247097 -1.4615446996829499e-5 4.737927800474284e-5 -5.662626510289128e-5 7.898041720349595e-5 0.00012663719742871956 4.565956176688293e-5; 6.040968241798923e-5 0.00010700802908907883 -0.00013008028668188204 0.00013264787798910653 -9.973736791941042e-5 0.00018789393311529747 6.609491197089898e-5 0.00029767339242319825 -0.00012588732693181139 -3.655107549243098e-6 2.019274347625e-5 9.936767354551735e-6 4.486933562369346e-5 3.891675464623005e-5 -7.401772359222851e-5 -0.0001895071216050977 -0.00011452335798018376 -3.1347884715765265e-5 7.583290084661267e-5 -1.939678236446106e-5 -0.00015249483453877967 -3.665886447485268e-5 1.4171150040815142e-5 0.0001292302442824028 0.00012168411581008939 -9.64592143201711e-5 -8.477287679048954e-5 -0.0002004259038141432 3.0735187691639164e-5 6.524801233248109e-5 2.0169673233644556e-5 -8.743083503865367e-5; -0.00015594568523967804 -6.432045322300979e-5 8.600838353002613e-5 -6.847963797783528e-5 1.2529190623792945e-5 -8.807312244712547e-6 3.45792120947208e-5 
-2.4681085253650954e-5 -8.225049370823592e-5 -9.367357981547387e-5 -9.418858301178444e-6 -5.133474971673586e-5 0.00014421450194047788 -4.407725120701351e-5 8.811265595566455e-5 7.145731958869727e-5 -5.687893846927348e-5 0.000112704585060308 5.9515854243255995e-5 0.0001358120078088889 -7.36492349665442e-5 -0.00011210645445546061 8.18189817232252e-5 6.882221512340005e-5 -5.7000850777077406e-5 7.864945089520762e-5 1.1340554717642497e-5 4.324252384092998e-5 -0.0001591366146590361 -1.388631810416017e-5 3.255396748270674e-5 -0.0001389950177531495], bias = [-1.2853470297038883e-9, 4.896051452503787e-10, 1.6379466293720113e-9, 1.2138759146976405e-9, -9.218426927451368e-11, -1.8230331087992903e-9, 1.727273527097074e-9, -3.643656384310119e-9, 1.135604007608576e-9, 3.9509734757092134e-9, 7.081922327030148e-10, 1.5803600726850138e-9, 3.649090973270635e-9, -3.7907175855678225e-9, -4.630365361585973e-10, -4.123261859208235e-10, 4.150028400466572e-9, 3.0780554282139766e-10, 2.2980427059247367e-9, -1.0615544228007823e-10, -1.8675891400887143e-9, -1.4192107834140116e-9, -1.3080312335908598e-10, 7.813016500553129e-10, -2.1085295864516417e-9, -4.0488919442204864e-9, -3.073417879929848e-10, 3.228635512057411e-9, -3.5387966208828077e-9, -4.664963470446139e-10, -9.072315647865157e-11, -5.950119176721757e-10]), layer_4 = (weight = [-0.0006446651666109494 -0.0007529018339533018 -0.0008093233860063905 -0.0006284760118885408 -0.0005968810790197654 -0.000585493681187751 -0.000837673994838595 -0.0007711489063037479 -0.0005633050097127248 -0.0006366020175565896 -0.0006351336841893969 -0.0007250663187550846 -0.0006773039539336318 -0.0007102623027784417 -0.0006984477441356452 -0.0007836149331594133 -0.0006547691546035754 -0.0006358416473538963 -0.0007439842062632546 -0.0007248791323072195 -0.0009062966064743499 -0.000753728407564923 -0.0005576711033216811 -0.0007266704450221633 -0.0006264206496902814 -0.0006150284993179414 -0.0006485911114611045 -0.0005555057596760502 -0.000723135451689339 
-0.0006958771766227974 -0.0007989775974893779 -0.0007756897085017334; 0.0002635064491133733 0.00013902321060088398 0.00027416288755319483 0.00036005136845857904 0.00026665041748154093 0.0003853675979829293 0.00020777463582142915 0.00024593198624691553 0.0003353161483059413 0.0003276606322724664 0.00015453736205649262 0.0002593210875760926 0.00010414091948133386 0.00012031890720450069 0.0003972533638995509 0.00012003154294842141 0.0003925132655046808 0.00023879112576484445 0.00026200251414252603 0.0002331804026158072 0.00031018298463556074 0.0002073419829842547 0.0003465754111874278 0.00015812656915209847 0.00019217357651452952 0.00011534668831958337 0.00010325712161342742 0.00021456723190503042 0.0001459926682109346 8.304199293465108e-5 0.00031881876799978747 0.00017356179484362758], bias = [-0.0006918680946778557, 0.00023882644622278632]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # Loss history collected in `losses`: a line with a marker at every recorded entry.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(
        ax,
        1:length(losses),
        losses;
        marker=:circle,
        strokewidth=2,
        markersize=12,
    )

    fig
end

Finally let us visualize the results

julia
# Rebuild the ODE problem around the optimized parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)

# Integrate with RK4 at a fixed step `dt` (adaptivity off), saving the state at `tsteps`,
# and convert the solution to a plain Array.
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)

# Only the first value returned by `compute_waveform` is kept.
waveform_nn_trained =
    first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Waveform data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax,
        tsteps,
        waveform;
        marker=:circle,
        markersize=12,
        strokewidth=2,
        alpha=0.5,
    )

    # Waveform from the neural ODE with untrained parameters.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax,
        tsteps,
        waveform_nn;
        marker=:circle,
        markersize=12,
        strokewidth=2,
        alpha=0.5,
    )

    # Waveform from the neural ODE with the optimized parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained; marker=:circle, markersize=12, strokewidth=2, alpha=0.5
    )

    # Each legend entry groups a line with its matching scatter plot.
    legend_plots = [[l1, s1], [l2, s2], [l3, s3]]
    legend_labels = ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"]
    axislegend(ax, legend_plots, legend_labels; position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU toolchain versions are printed only when MLDataDevices and the respective
# backend package are defined in this session and the backend reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA)
    if MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU)
    if MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.