Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie
Precompiling OrdinaryDiffEqLowOrderRK...
   3997.8 ms  ✓ OrdinaryDiffEqLowOrderRK
  1 dependency successfully precompiled in 4 seconds. 98 already precompiled.

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model, but from a scientific machine learning perspective they aren't particularly relevant.

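We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to recover $r_1$ and $r_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, §4.3): $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$ with $M = m_1 + m_2$.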

julia
"""
    one2two(path, m₁, m₂)

Split a relative one-body `path` into the paths of the two individual bodies using
centre-of-mass weights: body 1 follows `m₂ / M * path` and body 2 follows
`-m₁ / M * path`, where `M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an orbit given in `(χ, ϕ)` coordinates into Cartesian `(x, y)` positions.

`soln` holds one time sample per column and has either

- 2 rows `(χ, ϕ)`: `model_params = (p, M, e)` must then be supplied and `p`, `e` are
  constants, or
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary along the orbit and `model_params` is ignored.

The radius is `r = p / (1 + e cos χ)`. The result is a `2 × N` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.
"""
function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = @view soln[1, :]
    ϕ = @view soln[2, :]
    if nrows == 2
        # Check for `nothing` first: `length(nothing)` would throw a MethodError
        # instead of reporting the intended assertion message.
        @assert(
            !isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params  # M does not enter the orbit geometry
    else
        p = @view soln[3, :]
        e = @view soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

This function computes the first time derivative. It uses central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly sampled signal `v` with sample spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`; the first and last
points use second-order one-sided stencils, so the result has the same length as `v`
(which needs at least 3 samples).
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    centre = (v[3:n] .- v[1:(n - 2)]) / 2
    right = 3 / 2 * v[n] - 2 * v[n - 1] + 1 / 2 * v[n - 2]
    return vcat(left, centre, right) / dt
end
d_dt (generic function with 1 method)

This function computes the second time derivative. It uses the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly sampled signal `v` with sample spacing `dt`.

Interior points use the three-point stencil `v[i-1] - 2v[i] + v[i+1]`; the endpoints use
the one-sided stencil `2, -5, 4, -1`, so the result has the same length as `v` (which
needs at least 4 samples).
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    left = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    centre = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    right = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    return vcat(left, centre, right) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

One component of the trace-free moment tensor of a body on `orbit`, a matrix whose first
two rows are the `x` and `y` positions.

- `component == (1, 1)`: `mass * (x² - (x² + y²) / 3)`
- `component == (2, 2)`: `mass * (y² - (x² + y²) / 3)`
- any other component: `mass * x * y`
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    is_xx = component[1] == 1 && component[2] == 1
    is_yy = component[1] == 2 && component[2] == 2

    if is_xx || is_yy
        trace = x .^ 2 .+ y .^ 2
        diagonal = is_xx ? x .^ 2 : y .^ 2
        return mass .* (diagonal .- trace ./ 3)
    end
    # Off-diagonal components carry no trace correction.
    return mass .* (x .* y)
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (sample spacing `dt`) of the `component` entry of the
trace-free moment tensor of `orbit`, as computed by `orbit2tensor`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(1,1)`, `(1,2)` and `(2,2)` quadrupole components `(h11, h12, h22)` of a
single body of mass `mass` on `orbit`, sampled every `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` of a single unit-mass body following `orbit`
(rows `x`, `y`, sampled every `dt`).

The strain is built from the second time derivative of the trace-free quadrupole moment:
`h₊ ∝ h11 - h22` and `hₓ ∝ -2 h12`, both scaled by the mode normalisation `sqrt(π / 5)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation √(π/5), evaluated in the precision of `dt`.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum, component by component, the quadrupole contributions `(h11, h12, h22)` of two
bodies with the given orbits and masses, sampled every `dt`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)

    total_11 = body1[1] + body2[1]
    total_12 = body1[2] + body2[2]
    total_22 = body1[3] + body2[3]
    return total_11, total_12, total_22
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` of a binary: BH 1 of mass `mass1` on `orbit1`
and BH 2 of mass `mass2` on `orbit2`, both sampled every `dt`.

The masses must sum to one (checked to within `1e-12`). The quadrupole contributions
of the two bodies are summed before forming `h₊ ∝ h11 - h22` and `hₓ ∝ -2 h12`. Both
are scaled by the mode normalisation `sqrt(π / 5)`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation √(π/5), evaluated in the precision of `dt`.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` for the orbit `soln` sampled every `dt`.
`soln` and `model_params` are passed straight to `soln2orbit`.

`mass_ratio` must lie in `[0, 1]`:

- For `mass_ratio == 0` the one-body (test-particle) strain is used.
- Otherwise the relative orbit is split into two bodies with masses
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, which sum to one.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$u[1] = \chi, \qquad u[2] = \phi$$

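where $p$, $M$, and $e$ are constants (the semi-latus rectum, the mass, and the eccentricity; the orbital radius is $r = p / (1 + e\cos\chi)$).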

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = [χ, ϕ]` with constant
parameters `p`, `M`, `e`. Returns `[dχ/dt, dϕ/dt]`; `t` and `ϕ` do not enter the rates.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    c = cos(χ)

    # Factors shared by both rates.
    numer = (p - 2 - 2 * e * c) * (1 + e * c)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = numer * sqrt(p - 6 - 2 * e * c) / (M * (p^2) * denom)
    dϕ = numer / (M * (p^(3 / 2)) * denom)
    return [dχ, dϕ]
end

# mass_ratio = 0 selects the one-body strain path in `compute_waveform`.
mass_ratio = 0.0         # test particle
# Initial state [χ₀, ϕ₀].
u0 = Float64[π, 0.0]     # initial conditions
# Number of saved samples along the trajectory.
datasize = 250
# NOTE(review): Float32 literals here make `tsteps` and `dt_data` Float32 while `u0`
# and the ODE parameters are Float64 — confirm this mix is intended.
tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
# Spacing of the saved samples; used by the finite differences in `compute_waveform`.
dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step size (the solves below use `adaptive=false`).
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic orbit with fixed-step RK4, saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns (h₊, hₓ); only the plus polarization is kept as the target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line plus sample markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Combine the line and markers into a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the current value of $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
# 1 → 32 → 32 → 2 network with `cos` activations. Weights come from a truncated normal
# with a tiny std (1e-4) and biases start at zero, so the untrained output is close to
# zero.
const nn = Chain(
    # Presumably applies `cos` elementwise to the raw input before the first Dense layer.
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): `Random.default_rng()` is not seeded here, so the initial weights differ
# between runs; pass a seeded RNG (e.g. `Random.Xoshiro(0)`) for reproducibility.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-9.810709f-5; -4.749717f-5; 1.600672f-5; -2.0570134f-5; 8.47116f-5; -1.556238f-5; -3.1630112f-5; 0.00022141056; -9.744838f-6; 7.328204f-5; 0.00016253574; -2.312675f-5; 6.1258535f-5; -4.396796f-5; -1.0191237f-5; -5.0257106f-5; 0.00013160506; 8.645548f-5; 0.00014224235; 0.00010915904; 0.00017962937; 4.7297848f-7; 9.892074f-5; 8.767986f-5; 4.769619f-5; -3.2611486f-5; 4.0286137f-5; -2.9531187f-5; 0.00020246279; -7.734222f-5; -5.012816f-5; -9.175109f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.00017831847 0.00013048305 -0.00016112898 0.00012260898 -9.0377805f-5 -6.1589075f-5 0.0001146679 3.579894f-5 3.2997403f-5 -9.6493815f-5 3.049539f-5 1.048869f-6 6.756998f-5 -0.00011233788 0.00011161699 9.822821f-5 7.872565f-5 -8.817988f-5 -9.883017f-6 0.00014561195 0.00010579197 0.00010263919 -0.00013928692 3.6422385f-5 -1.6832617f-5 -9.294164f-5 -2.4132813f-5 -0.00020413958 0.0001615351 -1.3099786f-5 -4.7865222f-5 -4.432106f-5; 3.6371603f-5 6.631753f-5 7.0019036f-5 0.0002613199 0.00014353463 1.37724855f-5 0.00010948756 8.752543f-6 3.4356744f-5 -4.0200644f-5 -9.822746f-5 -0.00020784998 -2.105223f-5 4.1368123f-5 4.0607058f-5 -5.7999543f-5 1.8972818f-5 0.00011325985 0.0001860306 8.2499195f-5 -0.00020780491 -0.00020981889 2.5616904f-5 -5.331316f-5 -5.5564586f-5 -0.0002277269 8.710271f-5 -0.00015037943 -5.357252f-5 4.1597046f-5 -0.000110820394 5.427798f-5; 2.1553818f-5 -3.50902f-5 -0.00012888659 0.00012314774 3.8033318f-6 6.325454f-5 -1.0414451f-5 6.0419025f-5 0.000117232405 8.634658f-5 1.3190209f-5 -6.2175f-5 -4.6297544f-5 0.00019949918 -0.00010461776 8.584455f-5 8.585367f-5 4.089348f-5 2.6338364f-6 4.7280013f-5 -0.00017275818 0.00013958104 -4.4745244f-5 5.095953f-5 2.5381836f-5 -9.704044f-5 -2.5862842f-5 -9.144255f-6 1.3386219f-5 -1.7414086f-5 3.3762914f-5 
6.6373614f-6; 2.2517666f-5 -1.1072251f-5 -0.00012464053 0.000112476846 2.5459406f-6 -4.858711f-5 7.13806f-5 1.2623556f-6 1.5814723f-5 0.00013187755 0.00014069339 -0.0001532483 2.6222475f-5 0.00015579934 -0.00012208894 7.081364f-5 9.600022f-5 -1.6357959f-5 -0.00024375167 5.0704006f-5 -9.393173f-5 -9.171542f-5 -3.959161f-5 4.1147745f-5 -0.00010135991 -6.1357096f-5 4.280506f-5 4.5649194f-5 -0.00012840093 -0.00011396593 -0.00010041804 -0.00013113317; -4.0188865f-5 1.0492269f-5 -6.9983646f-5 0.00010012844 4.922639f-5 5.9135502f-5 -2.1056503f-5 0.00012527317 0.000101697835 -1.7230204f-5 -1.7405786f-5 -4.1217314f-5 7.735634f-6 -0.00011539236 -3.6085963f-5 6.3031956f-5 2.6408907f-6 0.00010528684 6.242802f-5 0.000111430636 -2.8434462f-5 -1.5672136f-5 0.00021007567 0.00010317422 -4.0257793f-5 8.3446655f-5 -3.2094365f-5 0.000114282404 -0.00013151439 -0.00011388785 1.5004865f-5 -0.00010683882; -6.6293855f-5 -3.2051278f-6 1.3235997f-5 6.298265f-5 -0.00014937457 -1.919268f-5 -1.18368525f-5 -7.385695f-6 0.00016536523 8.039409f-6 8.6282955f-5 0.0001296013 -3.52472f-5 -0.00011726266 0.00012110018 -5.5182092f-5 0.0001232943 0.00014545997 -8.299895f-5 7.9142585f-5 -5.3638207f-5 -0.00016157863 -0.00026922466 -0.0001930139 3.3256433f-6 3.0414687f-5 7.541309f-5 -7.1066584f-5 9.803255f-5 -2.7930666f-5 -6.7389005f-6 6.7134955f-5; -0.00010834147 -7.178083f-5 -0.0001102978 5.716197f-5 6.9371636f-5 -7.7073695f-5 4.5339293f-5 5.8977548f-5 9.24073f-5 3.7500453f-5 -0.00019697407 -0.00022948452 -4.8799062f-5 -6.148491f-5 -4.369149f-5 -6.691656f-5 -2.662065f-5 8.612358f-5 -0.00016693558 8.440058f-6 5.852841f-5 8.255548f-5 -8.3711195f-5 -3.71884f-5 -3.4188775f-5 8.050637f-5 -2.5978943f-5 0.000113401446 -5.952774f-5 5.737583f-5 -5.5282308f-6 -0.00019618808; 8.458307f-5 -5.4504842f-5 6.125127f-5 -3.920111f-5 4.379538f-5 -9.889507f-5 1.2756687f-5 8.411042f-5 -0.00015521604 0.00012297934 6.64283f-5 8.542779f-5 3.111206f-5 -8.1984785f-5 -4.4515346f-5 -5.658797f-5 0.00011596129 7.35414f-5 2.2176462f-5 
-8.755205f-5 2.4384884f-5 5.618803f-5 -1.3446258f-6 -2.2676102f-5 -8.585275f-5 0.00012099342 -1.7545095f-5 1.030256f-5 8.033761f-5 1.3839005f-5 -1.1059399f-5 3.2643882f-5; 2.8487544f-5 0.00010866878 0.00017608773 -7.7042285f-5 -1.032003f-5 -2.7283668f-5 -0.00019456362 -4.5505596f-5 0.00011511678 0.00015239358 4.2203214f-5 -2.2311556f-6 2.4891428f-5 -0.00018206543 -1.3418422f-5 -5.6423625f-5 -2.0089767f-5 6.8744026f-5 -6.0993374f-5 2.0996134f-7 -0.00012740446 0.00012638375 -0.00014958643 2.9093984f-5 2.8274326f-5 -0.0002014046 -0.0002406088 7.101039f-5 0.00017910643 -6.8285357f-7 0.0001489658 -0.000105213476; 0.00013806055 7.951105f-5 -7.2372306f-5 8.5247586f-5 7.266868f-6 -6.89452f-5 -0.00011507749 -8.706957f-5 -7.1678165f-5 -5.970952f-5 -3.7076417f-5 -8.006767f-5 -0.00017664685 -0.00013367417 -3.2387605f-5 0.00016269954 -0.00021597625 0.0002264972 0.000168504 -1.9822182f-6 1.7750448f-5 0.00012274506 0.00015425276 -2.1529306f-5 -0.00024816158 -2.3959448f-5 6.8709174f-5 -0.00023176367 -5.507069f-5 0.0001449904 1.9338326f-5 -0.00010870709; -3.6332814f-5 -5.783667f-5 4.6357833f-5 7.593764f-5 -9.931773f-5 0.00015098549 3.416752f-5 -0.00010515108 9.021634f-5 -1.323271f-5 0.00018576968 -8.644768f-5 0.000177251 4.528287f-5 7.2876384f-5 -1.3869445f-5 -0.00014132507 -9.6190866f-5 8.740329f-5 -3.3251534f-5 9.6230455f-5 9.040136f-6 -0.00017988752 -0.00021067941 1.9947169f-5 5.112547f-5 -8.649439f-5 0.00013431573 9.594309f-5 -1.7656595f-5 1.931575f-5 -8.649698f-5; 0.00014105525 -0.00010198477 3.4284042f-6 -5.7305224f-5 -0.00011009428 4.8797632f-5 0.000112590336 -0.00016347916 9.7031385f-5 0.00014390679 -0.00016031471 0.00016881159 4.4573033f-5 -6.288785f-6 4.89793f-5 -4.143615f-5 -2.6298994f-5 2.385294f-5 7.1207887f-6 3.5189063f-5 -9.8048426f-5 0.00023028531 -4.8888327f-5 9.882595f-6 -0.00011440133 3.1167325f-5 1.549816f-5 -8.0829144f-5 5.3974636f-5 -3.8592934f-5 0.000103636485 -0.00011887275; -4.726439f-5 6.3562846f-5 -0.00014710319 -7.059829f-5 8.022609f-5 2.930132f-5 
0.00016946891 -0.00020918473 2.81718f-5 -0.00010618494 -6.9092515f-5 -0.00013334276 0.00011526275 -0.0001062301 5.5350607f-5 9.170354f-5 1.978034f-6 0.000102813676 -4.537009f-5 -0.000114416136 4.772478f-5 3.9205595f-5 7.0800415f-5 9.731428f-5 -7.601688f-5 -8.5779255f-5 -0.00013421553 4.206156f-6 -1.31570205f-5 -9.285651f-5 0.00016286838 -2.8885068f-5; -8.230358f-5 0.00018240756 0.00023388234 4.331204f-6 0.00014469658 -2.8933624f-5 4.9720078f-5 -3.661517f-5 -8.603343f-5 -1.3673653f-5 9.111441f-5 -0.00010850325 -4.5236276f-5 0.00016326157 -0.00010101279 -6.962656f-5 9.863074f-5 -4.334359f-5 5.3590997f-5 4.9475613f-5 0.000106400475 -0.00011895285 -4.055545f-6 -1.7906134f-5 -0.00017077343 -0.00011536041 -2.5196985f-5 -5.195108f-5 1.705539f-5 9.28549f-5 -7.5333846f-5 -9.226754f-5; 9.06431f-5 -5.2244268f-5 -5.553575f-5 4.6864232f-5 1.9666355f-5 -0.0002761014 0.00016611224 7.1211034f-6 7.486165f-5 -7.495144f-5 7.169138f-5 2.7971522f-5 -3.501307f-5 -5.1268064f-5 0.000106179126 -0.00022141472 -7.3631876f-5 4.9716276f-5 -0.0001017034 -0.00011749811 -0.00010177304 -0.00021228686 0.00011560145 0.00013458572 -0.00015890168 0.00015659345 -0.00019361424 -5.5203738f-5 5.5676686f-5 -4.0314512f-7 -1.5286634f-5 0.00019765651; -3.448417f-5 -0.00014828869 -0.00020502029 -2.398188f-5 -0.00012325992 1.6724123f-5 -0.00014364115 -0.00011470666 -0.000119083736 1.0166186f-5 -4.8746442f-5 0.00017347891 7.415056f-5 -1.4329463f-5 -0.00014800715 5.811935f-5 -0.00020270864 7.730159f-5 3.6368383f-5 0.00020435649 5.4111915f-5 6.47488f-5 -8.871129f-5 -6.2516534f-5 1.0232614f-5 0.00011891945 0.00015900124 7.6371754f-5 1.2643379f-5 0.00015081598 -9.473879f-8 -2.9365518f-5; 0.00014368795 -0.00017664148 3.6761492f-5 -3.83041f-5 -2.9995585f-5 -1.2543608f-5 9.558004f-6 9.525164f-6 -1.4304904f-5 0.00017951602 1.2204288f-5 4.3094165f-5 -0.00011037553 6.708922f-5 0.00010106517 -2.9556986f-5 -1.6057296f-5 -1.3431018f-6 4.7254365f-5 -4.1416304f-5 -9.181493f-6 0.00012376747 0.000101125785 -3.3063472f-5 
-9.868658f-5 -5.4863827f-5 5.2404688f-5 3.641202f-5 1.7690727f-5 9.175237f-5 3.7718906f-5 -3.2505937f-5; -0.00013234185 -1.6075597f-5 8.1915096f-5 -0.00019550318 -3.5431285f-6 -9.165008f-5 5.2527645f-5 -1.02270715f-5 6.9989874f-5 -0.00017549437 4.416361f-5 8.6359505f-5 -2.3016526f-5 9.414863f-5 3.7319754f-5 7.428378f-6 9.207705f-6 0.00011332948 1.9848918f-5 1.3857514f-5 -2.2603715f-6 3.0573643f-5 -3.426298f-5 -0.00014821983 -0.000110408444 -4.9823422f-5 -6.666585f-5 -2.2594968f-5 0.0001503234 -5.9922877f-5 7.380893f-5 4.8794787f-5; -4.2962045f-5 -4.0543695f-5 -6.856471f-5 2.9472263f-5 -5.966639f-5 0.00021193914 0.0002819889 4.4656415f-5 2.1428383f-5 0.0001236945 1.2834069f-5 -0.0001069364 1.8655579f-5 3.3167966f-5 1.8353108f-5 -0.00014206959 -8.2008795f-5 7.183409f-5 4.332012f-6 3.6658384f-5 8.797239f-5 3.6662982f-7 4.6839225f-5 -0.000103775776 6.764079f-5 -6.746253f-5 2.1805345f-5 2.7018561f-5 -4.9196078f-5 -0.00022031223 7.0670256f-5 8.173493f-5; -6.391918f-5 -6.620508f-5 0.0001224764 -8.386396f-6 -3.734335f-5 -0.00021373911 1.8073753f-5 1.5932723f-5 1.5025505f-5 -0.00013454967 -0.0002356609 0.00016592018 -1.6299513f-5 -7.31518f-5 1.0992012f-5 -7.4780415f-5 5.5459486f-6 -3.243614f-5 0.00013386854 7.138719f-6 6.0752798f-5 -2.4450148f-6 -4.4769577f-6 5.4034135f-5 -0.000201196 -0.00015252398 -5.5589342f-5 0.00013092643 -7.0941656f-5 7.7901466f-5 -8.731441f-6 0.0001788312; 0.0002031958 -0.00013252842 3.3278913f-5 -0.00010553443 -1.656765f-5 -7.522686f-5 0.00017397261 -0.00012271895 -5.5088327f-5 7.315382f-5 -1.373772f-5 0.00018888527 -7.855747f-5 -5.5263754f-5 1.0922884f-5 -5.0371633f-5 0.00010006879 -9.952291f-7 -9.4112525f-5 6.212597f-5 7.856212f-6 4.6229103f-5 7.2144985f-6 0.00027732033 -0.00010643825 2.6562744f-5 7.348312f-5 -0.00010819968 -4.662509f-5 1.42356985f-5 5.810332f-5 -5.7966456f-5; -1.4724547f-5 0.00014941843 7.752716f-5 1.5741793f-5 -6.011349f-5 -7.4907715f-5 -1.9248184f-5 0.0001524555 -0.0002414251 -0.00018178033 0.00021011547 -6.34535f-5 
-0.00028005798 5.6202607f-6 -0.00013171678 -3.4294528f-5 -0.000127753 -5.8847352f-5 -1.0592568f-5 9.70805f-5 -6.5364395f-5 7.441079f-5 -0.0001347843 -0.00012571308 0.00012879855 -6.1195846f-5 4.394772f-5 -6.3580985f-5 2.0129697f-5 1.5130864f-5 0.00013361538 -2.9441948f-5; 0.00010349854 9.909467f-5 8.710504f-5 3.957707f-5 6.0987983f-5 2.5881262f-5 -8.561335f-6 -0.000117059426 -2.0855887f-5 -5.8820126f-5 -0.00016662043 -4.5370038f-5 -4.6271485f-5 2.9613197f-5 0.00022900189 0.0002517378 -4.9606144f-5 7.005932f-5 7.6633165f-5 -5.9095997f-5 1.4510472f-5 -5.3626703f-5 -4.5814095f-6 0.00010735821 3.309915f-5 -5.1771617f-6 1.6324844f-5 8.341973f-5 0.00024525632 0.00022049803 6.0460756f-5 -0.00011575212; -0.0001170186 0.0002235445 8.813937f-5 0.00017171876 -0.00012718904 0.00011505093 -2.077857f-5 3.649261f-5 1.0523547f-5 8.152253f-5 -0.00014199223 0.00016379097 -0.00016070351 7.9805955f-5 8.807329f-5 -0.00015770573 1.6006173f-6 -0.00021022714 0.00012787544 -0.00012445668 -3.6981997f-5 -3.3070017f-5 5.5262662f-5 -3.689487f-5 -0.0001348591 2.4601395f-5 -0.00012493394 5.460138f-5 -0.00026866814 4.1382514f-6 1.7270982f-5 -8.6257045f-5; -0.00016867605 5.0551757f-6 -8.461162f-5 -0.00013218026 0.00012631598 0.00028391412 3.4609726f-5 8.829556f-5 -6.50602f-5 -0.00016086297 -1.7267752f-5 0.00012911898 2.9222349f-5 0.00014630069 7.5276344f-5 0.00015988052 8.485604f-5 -0.00017669296 9.1420465f-5 -4.5420576f-5 5.766383f-7 6.3196196f-5 -4.5319535f-5 -3.283137f-5 7.53801f-5 -1.6106547f-5 1.23243235f-5 1.8058568f-5 7.162285f-5 0.00011673928 -0.00012657577 -0.00030242157; -0.00016350791 -0.00021130503 8.443801f-5 0.00020810301 1.1664595f-5 -7.057608f-5 2.8012932f-5 -4.4058645f-5 0.00011749628 7.407136f-5 4.6954352f-5 -4.7776142f-5 9.9503544f-5 4.40487f-5 -0.00012847455 6.855877f-5 -4.541757f-5 -9.812558f-6 8.370624f-5 -7.250004f-5 0.0001356357 2.3445908f-5 9.556296f-5 5.6949386f-5 -0.00015990589 -3.0129111f-5 0.0001219949 0.00011901714 -0.00014171422 -3.3430075f-5 -7.1002265f-5 
-0.00012150418; 8.4324834f-5 0.00026151573 9.2143135f-5 -0.00013689007 -4.848433f-5 8.8735855f-5 -5.4888093f-5 -2.5099755f-5 5.953152f-5 1.7843387f-5 7.7917684f-5 6.3396124f-5 9.561543f-5 -0.00012852145 3.675665f-5 0.000162002 -0.00011043555 7.566079f-5 0.00026878476 -0.00018029648 -0.00013366198 0.00018316452 0.00018161481 2.426654f-5 -0.00011179578 -9.626906f-5 -3.9661227f-5 3.4301327f-5 -4.5936726f-5 -4.0444444f-5 -8.428332f-6 2.8090584f-5; 0.00012333789 -3.5217436f-5 7.4736395f-6 -4.237652f-5 0.00024973322 -0.00016850958 4.5044042f-5 0.00017318233 -5.709483f-5 -5.795766f-5 -5.331043f-5 -9.3429975f-5 -9.351519f-5 -6.8994115f-5 -0.00013938319 3.3829023f-5 -5.070276f-5 -4.8729293f-5 6.1856124f-5 1.5715927f-5 4.8659553f-5 0.00019023396 0.00017565765 -9.050845f-5 -3.2672685f-6 8.9474364f-5 -8.118979f-5 0.00012850513 -2.6361638f-5 -3.827526f-5 0.00011786288 -3.9273495f-5; 7.5717304f-5 -0.00015946805 -0.0003157078 0.00011017467 0.00015334766 7.246223f-9 3.540672f-5 -8.0845304f-5 0.00014456589 -4.853914f-5 -0.00013660434 -9.431877f-6 0.00015708027 6.551907f-5 4.6072793f-5 0.00014523002 -7.181377f-5 9.87449f-5 -0.00012234181 -0.00021378357 0.00015101807 0.00013680985 4.2787397f-5 -0.00012591926 0.000126174 -7.575162f-5 -4.2705018f-5 -5.0344086f-5 -4.7568115f-5 0.000116251664 0.00014546647 -0.00014367663; 1.6718592f-5 -9.135359f-6 -0.00015166422 0.00016329678 -8.5307045f-5 -6.322898f-5 8.342223f-5 0.00010435035 -1.8903153f-5 -0.00015656784 5.0264392f-5 2.3026989f-5 -6.9934285f-5 6.7463654f-5 -7.578899f-5 0.00020811461 -9.1383146f-5 -1.0786738f-5 7.713044f-5 2.710944f-5 7.718914f-5 7.385855f-5 3.0626423f-5 -5.8603076f-5 -6.245301f-5 -2.20037f-6 0.00016838421 5.1463277f-5 0.00011029318 -6.975102f-5 -0.0001119742 -0.00010929362; 6.3219275f-5 5.031506f-5 -8.269195f-5 -3.4634202f-5 -2.7699145f-5 6.0748002f-6 -6.6148525f-5 -0.00014981315 0.00010059232 4.3198303f-5 2.3191853f-5 2.522342f-6 -2.5354264f-5 0.00022769372 5.1528616f-5 6.357483f-5 -3.8362967f-5 -0.00020531793 
7.4727395f-5 -2.9437f-5 -7.824605f-6 0.00013268444 -0.00014984203 0.00012942351 -0.00011811632 7.803113f-5 -6.322592f-5 7.315007f-5 6.3714388f-6 -1.8846746f-5 3.2029653f-5 -1.726757f-5; 1.7118775f-5 4.541132f-5 -2.9635958f-5 9.491427f-5 0.00017626355 0.00013812688 -1.681752f-5 6.431639f-5 -0.00012204213 5.7002686f-5 1.4878468f-5 -4.558555f-5 -2.9358425f-5 -0.00010720979 0.00012627129 9.329562f-5 -5.0097384f-5 -8.3081264f-5 2.8971048f-5 -4.207406f-5 5.162851f-5 5.5022578f-5 -3.489694f-5 -0.00022156161 0.0001259785 3.805466f-5 0.00014479135 0.00013058579 -0.00016514637 0.00012055444 4.983819f-5 -3.3389726f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-9.222426f-5 -4.8276735f-5 4.298229f-5 -6.709387f-5 0.00022337107 -1.6993967f-5 -8.882483f-5 0.00014664381 0.00016097071 -0.00012020253 -4.2017553f-5 -0.000112783295 -5.8513804f-5 -5.4845026f-5 -7.675401f-5 1.1575542f-6 7.843759f-6 -0.0001140299 -4.6318204f-5 -0.00011077231 8.486885f-6 7.148576f-5 3.8911076f-5 -0.0001858529 -0.00014754421 1.777367f-5 -0.00015123507 -5.555136f-6 -4.1275987f-5 -0.000107691914 4.7523114f-5 0.00013522271; 3.1594976f-5 -1.7241637f-6 -7.091568f-5 -1.564293f-5 0.00016775372 1.4364453f-5 0.00010334371 -0.0002275463 -6.437595f-5 -2.3377805f-5 7.552507f-8 -0.00025938606 0.00015045752 8.9680805f-5 -1.240913f-5 2.2359103f-5 -3.0009169f-5 0.000103488754 7.0405935f-5 8.4785905f-5 3.1545867f-5 9.4464216f-5 0.00012497921 5.9486458f-5 -5.787111f-5 -5.5649038f-5 -7.593884f-5 0.0001736212 7.846955f-5 -8.5739906f-5 -8.473465f-6 -6.4856315f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using `Float32`; however, in this case we need `Float64`.

julia
# Convert the Float32 parameters from `Lux.setup` to Float64 and wrap them in a
# ComponentArray, which is what the ODE problem and optimizer receive below.
const params = ComponentArray(f64(ps))

# Bundle the network with its (empty) state so it can be called as `nn_model(x, ps)`;
# the `nothing` slot means parameters are supplied explicitly on every call.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics, with each rate scaled by a neural-network correction, using

$$u[1] = \chi, \qquad u[2] = \phi$$

where $p$, $M$, and $e$ are constants (the semi-latus rectum, the mass, and the eccentricity).

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = [χ, ϕ]`.

Both rates share the Newtonian factor `(1 + e cos χ)^2 / (M p^(3/2))`, with `p, M, e`
read from the global `ode_model_params`. Each rate is then multiplied by `1 + NN(χ)`,
the network correction evaluated at the current `χ` with parameters `nn_params`.
`t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` carries the network state internally. Here that state is an empty
    # NamedTuple, so it can be ignored; in general `st` holds the network's state.
    correction = 1 .+ nn_model([χ], nn_params)

    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings and sampling as the true model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained model's waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # One legend entry per series, each combining its line and markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean-squared-error criterion used by `loss`.
const mseloss = MSELoss()

"""
    loss(θ)

Mean squared error between the plus polarization produced by the neural ODE with
parameters `θ` and the target `waveform`, both sampled at `tsteps`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end
loss (generic function with 1 method)

Warm up the loss function (the first call triggers compilation):

julia
# The first call compiles the solve/waveform/loss pipeline before training starts.
loss(params)
0.0006912158143264871

Now let us define a callback function to store the loss over time

julia
# Loss history, appended to once per optimizer iteration by `callback`.
const losses = Float64[]

"""
    callback(θ, l)

Optimizer callback: record the current loss `l` in `losses` and print a progress line
using the iteration counter `θ.iter`. Always returns `false` so training continues.
"""
function callback(θ, l)
    push!(losses, l)
    iteration = θ.iter
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", iteration, l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It gives good results here because the Newtonian model already provides a very good initial guess.

julia
adtype = Optimization.AutoZygote()
# The objective ignores Optimization.jl's hyper-parameter argument; everything else is
# captured by `loss` itself.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = let
    # `initial_stepnorm` scales the initial inverse-Hessian guess so that the first BFGS
    # step is small; backtracking then shrinks steps that do not decrease the loss.
    linesearch = LineSearches.BackTracking()
    Optimization.solve(
        optprob, BFGS(; initial_stepnorm=0.01, linesearch); callback, maxiters=1000
    )
end
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-9.810709161674188e-5; -4.749717118102386e-5; 1.60067193064572e-5; -2.0570134438462835e-5; 8.471160253970149e-5; -1.5562380212918337e-5; -3.163011206197262e-5; 0.00022141056251712498; -9.744838280311785e-6; 7.328204083018828e-5; 0.0001625357399461849; -2.312674951098427e-5; 6.125853542467284e-5; -4.3967960664294466e-5; -1.019123737932146e-5; -5.025710561304781e-5; 0.00013160506205165198; 8.645548223285593e-5; 0.00014224235201236646; 0.00010915903840219251; 0.0001796293654478892; 4.729784848219133e-7; 9.892074012888058e-5; 8.767985855227449e-5; 4.7696190449608475e-5; -3.261148594901416e-5; 4.0286136936604864e-5; -2.9531187465162272e-5; 0.00020246279018451286; -7.734222162978408e-5; -5.012816109228266e-5; -9.175109153150993e-5;;], bias = [-1.0831029703752683e-16, 2.7255306112797796e-18, 3.843953745337599e-18, 5.9260630322530634e-18, 5.915278771255347e-17, -2.561487105259337e-17, -3.2593657346144397e-18, -3.910333795227343e-18, -4.556269501842252e-19, 5.3223650958815485e-17, -3.765765047677944e-18, -1.0271805338083189e-17, 4.527914306103736e-17, -1.069277069670921e-17, -1.0695817468953485e-17, -8.849797284554698e-17, 2.6990803502192818e-17, -7.18805924313449e-18, 1.997448595589829e-16, -5.924491386626335e-17, 2.052701074432037e-17, 4.486857932208938e-19, 8.77150120164852e-17, 1.8416390524110923e-17, 3.76542289325031e-17, -4.991316939695206e-18, 2.725804783071943e-17, -2.1296629555395842e-17, 1.5892369374058714e-16, -2.374566104622497e-17, -6.80606440027955e-18, 3.5279400438530625e-17]), layer_3 = (weight = [-0.00017831848735941678 0.00013048303453696783 -0.0001611289938401628 0.00012260896410306807 -9.037782210644442e-5 -6.158909154610821e-5 0.00011466788209954165 3.57989244556791e-5 3.2997386097934856e-5 -9.649383200585951e-5 3.049537344931528e-5 1.0488523238504952e-6 6.756996090270468e-5 -0.00011233789909884445 0.00011161697120850793 9.822819056434445e-5 7.872563373975352e-5 -8.817989355439428e-5 
-9.883033543798212e-6 0.00014561193320413948 0.00010579195595791468 0.00010263917438432461 -0.0001392869328802648 3.6422368521829004e-5 -1.6832633536165753e-5 -9.294165856691189e-5 -2.4132829184151434e-5 -0.00020413959491552894 0.00016153508995639073 -1.3099802617131415e-5 -4.786523850189631e-5 -4.432107589228185e-5; 3.6371390686131986e-5 6.631731864925795e-5 7.001882384227895e-5 0.0002613196765903787 0.00014353442291193064 1.3772273612435241e-5 0.00010948735050786598 8.752331248423443e-6 3.435653250361166e-5 -4.0200856367132e-5 -9.822766971611863e-5 -0.00020785019389424166 -2.105244222749626e-5 4.1367910833091864e-5 4.060684566664814e-5 -5.799975499231978e-5 1.897260606509591e-5 0.00011325963621655834 0.00018603038095333125 8.249898289958401e-5 -0.00020780512661277724 -0.00020981909712847018 2.5616692235222072e-5 -5.3313372538663596e-5 -5.5564797569072396e-5 -0.00022772710601404 8.710249583399239e-5 -0.00015037964569475194 -5.3572731323776495e-5 4.1596834287507195e-5 -0.00011082060611644439 5.42777697571962e-5; 2.1555526925327938e-5 -3.508849249806992e-5 -0.0001288848811807453 0.00012314944478679745 3.8050402483697315e-6 6.325625008434362e-5 -1.0412742542719628e-5 6.0420733004514516e-5 0.00011723411307433934 8.634829145852499e-5 1.3191917330210836e-5 -6.217329201059126e-5 -4.6295835548116633e-5 0.00019950088471636383 -0.00010461604824991582 8.584625765910552e-5 8.585538170994555e-5 4.089518779186438e-5 2.6355448855886495e-6 4.7281721433351166e-5 -0.00017275647542444583 0.00013958274803120964 -4.474353545698557e-5 5.0961238682464576e-5 2.538354465035915e-5 -9.703873502326868e-5 -2.586113396569857e-5 -9.14254678969699e-6 1.33879270808554e-5 -1.7412377476613147e-5 3.3764622000709114e-5 6.639069849586526e-6; 2.2515684186640737e-5 -1.1074233434066183e-5 -0.00012464251215484447 0.000112474863790828 2.5439583810493518e-6 -4.858909181117536e-5 7.137861809545719e-5 1.2603733881147184e-6 1.5812740976377394e-5 0.00013187556448478129 0.000140691405768444 
-0.0001532502823733582 2.6220492664414947e-5 0.0001557973560652847 -0.00012209092113104182 7.081166092624576e-5 9.599823949479209e-5 -1.635994106387245e-5 -0.0002437536490373975 5.0702023664185e-5 -9.393371458776554e-5 -9.171739969084512e-5 -3.95935925821106e-5 4.1145762743819744e-5 -0.00010136189251121459 -6.135907772610368e-5 4.280307673119573e-5 4.564721218279148e-5 -0.00012840291622472625 -0.00011396791108332662 -0.0001004200197980584 -0.0001311351547241949; -4.018700879118062e-5 1.0494125079790522e-5 -6.998178964880321e-5 0.00010013029582166077 4.922824568762933e-5 5.913735781047467e-5 -2.105464723302589e-5 0.0001252750303579293 0.00010169969077520988 -1.7228348452987688e-5 -1.7403930043156356e-5 -4.121545812398239e-5 7.737490294461564e-6 -0.00011539050269805951 -3.608410720647331e-5 6.303381137383008e-5 2.642746538690979e-6 0.00010528869511159609 6.242987778801179e-5 0.00011143249189313706 -2.843260649702127e-5 -1.5670280346451444e-5 0.00021007752416022458 0.00010317607716257211 -4.025593757563419e-5 8.344851134876646e-5 -3.209250958337558e-5 0.0001142842598242772 -0.00013151253319980253 -0.00011388599473451635 1.5006721285831421e-5 -0.00010683696696236314; -6.629429572453099e-5 -3.205568877405081e-6 1.32355562488155e-5 6.29822125080463e-5 -0.00014937500691698276 -1.919312035565241e-5 -1.1837293645243282e-5 -7.38613608003688e-6 0.0001653647877862167 8.038968185843936e-6 8.628251361702498e-5 0.00012960086587159456 -3.524764270363767e-5 -0.0001172630991316908 0.00012109973885858917 -5.518253329168483e-5 0.00012329386208480956 0.00014545952853555453 -8.299939236561925e-5 7.914214347297982e-5 -5.3638647845527e-5 -0.00016157907248666025 -0.00026922510460676 -0.00019301433804170564 3.325202210959339e-6 3.0414245880507283e-5 7.541264767101544e-5 -7.106702560517441e-5 9.803210832220643e-5 -2.7931107073839633e-5 -6.739341646920097e-6 6.713451420697057e-5; -0.00010834430884542873 -7.178366957206026e-5 -0.00011030063922490856 5.7159130502520966e-5 6.936879492158779e-5 
-7.707653583466939e-5 4.533645244108651e-5 5.897470748199078e-5 9.240446217617074e-5 3.7497612193958664e-5 -0.00019697691039153675 -0.0002294873587551704 -4.880190261147524e-5 -6.148775299053417e-5 -4.369432952345294e-5 -6.691940086867383e-5 -2.6623490501965676e-5 8.612074145375554e-5 -0.00016693842497160602 8.437217385198293e-6 5.8525569894409326e-5 8.255264094358353e-5 -8.371403544415703e-5 -3.719124231693511e-5 -3.419161601314998e-5 8.050352755261728e-5 -2.5981783782218504e-5 0.0001133986048635853 -5.953058223792236e-5 5.7372989086701e-5 -5.5310714550707744e-6 -0.00019619091779432367; 8.458452707560541e-5 -5.45033880829889e-5 6.125272107016538e-5 -3.9199655430305247e-5 4.379683291478004e-5 -9.889361622338871e-5 1.2758141610508087e-5 8.411187359301212e-5 -0.00015521458494771018 0.00012298079586759217 6.64297540875427e-5 8.54292439267435e-5 3.111351456897415e-5 -8.198333030162212e-5 -4.4513891180425224e-5 -5.658651657580258e-5 0.00011596274184699552 7.354285394245669e-5 2.217791647995603e-5 -8.755059475373773e-5 2.4386338781869633e-5 5.6189484029861866e-5 -1.3431715164533908e-6 -2.2674647835116918e-5 -8.585129485294073e-5 0.0001209948741685061 -1.754364070660254e-5 1.0304014767431481e-5 8.033906401537571e-5 1.3840459405004293e-5 -1.1057944857779429e-5 3.264533647740124e-5; 2.8486738792111456e-5 0.00010866797178036185 0.00017608692954976642 -7.704309015085384e-5 -1.0320835336330222e-5 -2.7284473126800072e-5 -0.00019456442301876297 -4.5506401384631866e-5 0.00011511597634975396 0.00015239277357917942 4.220240820300111e-5 -2.2319608985148006e-6 2.4890622207151433e-5 -0.00018206623413849048 -1.341922730776652e-5 -5.642443053207701e-5 -2.0090571943595592e-5 6.87432207444941e-5 -6.0994179385172893e-5 2.0915602075987687e-7 -0.00012740526786787936 0.0001263829422344692 -0.00014958723246368734 2.9093178945316722e-5 2.8273520492204718e-5 -0.00020140540933485525 -0.00024060960773627645 7.100958695386153e-5 0.00017910562250049507 -6.836588892445835e-7 0.00014896499718756073 
-0.00010521428108464184; 0.000138058971667098 7.950947171299721e-5 -7.237388793622244e-5 8.524600404304361e-5 7.265286209776769e-6 -6.894678093270499e-5 -0.00011507907341477521 -8.707115496982273e-5 -7.167974702791418e-5 -5.971110234279952e-5 -3.707799927107675e-5 -8.006925357516008e-5 -0.00017664843047280717 -0.0001336757516885271 -3.238918664759132e-5 0.00016269795438963005 -0.00021597783355284638 0.00022649561712275533 0.0001685024223359301 -1.9838001510171467e-6 1.77488661090905e-5 0.0001227434737909481 0.00015425117878000642 -2.1530887655198537e-5 -0.00024816316302643147 -2.396103021349816e-5 6.870759231493778e-5 -0.00023176525003985572 -5.507227191011288e-5 0.00014498881197757285 1.933674381395353e-5 -0.00010870867420929163; -3.633235475022046e-5 -5.783621101031603e-5 4.6358292354535965e-5 7.593810140647433e-5 -9.931727118996568e-5 0.000150985951666054 3.4167978344750204e-5 -0.00010515061741203359 9.021679715047102e-5 -1.3232250705504426e-5 0.0001857701367840921 -8.644722266794679e-5 0.00017725146016132475 4.528333146267984e-5 7.287684410371329e-5 -1.3868985217731123e-5 -0.00014132461229778638 -9.619040657739756e-5 8.740375186575718e-5 -3.3251074012105496e-5 9.623091453945227e-5 9.040596057867865e-6 -0.0001798870566857204 -0.00021067895296469002 1.9947628504861508e-5 5.112592906484807e-5 -8.64939270398724e-5 0.00013431619435109624 9.594355059347878e-5 -1.765613570782934e-5 1.9316209989166215e-5 -8.649651728078408e-5; 0.00014105580134955584 -0.00010198421864597577 3.42895509593297e-6 -5.730467296498315e-5 -0.00011009373179231378 4.879818311483185e-5 0.00011259088728284034 -0.0001634786042367609 9.703193586492498e-5 0.00014390733645004643 -0.00016031415930311144 0.0001688121390080596 4.457358386307478e-5 -6.288233863324656e-6 4.89798492245428e-5 -4.143559840567866e-5 -2.6298443345607e-5 2.385349011919848e-5 7.1213396487075145e-6 3.5189613981035824e-5 -9.80478746450046e-5 0.00023028586087921223 -4.8887776075300095e-5 9.883145820912084e-6 -0.00011440077855777447 
3.116787570352236e-5 1.5498711717645894e-5 -8.08285933552725e-5 5.397518702838361e-5 -3.8592383362721246e-5 0.00010363703563516505 -0.00011887219648169915; -4.726552569103302e-5 6.35617115310751e-5 -0.00014710432068590297 -7.05994235342949e-5 8.022495686396989e-5 2.930018573330027e-5 0.00016946777842327757 -0.00020918586150421115 2.8170666073274382e-5 -0.00010618607401760329 -6.909364955411477e-5 -0.00013334389867872214 0.00011526161771122274 -0.00010623123588651578 5.5349472777161334e-5 9.170240892650421e-5 1.976899884327449e-6 0.00010281254154484764 -4.537122285021883e-5 -0.00011441727027907759 4.7723647103132984e-5 3.920446116331733e-5 7.079928119288886e-5 9.731314722536183e-5 -7.601801106851694e-5 -8.578038895415335e-5 -0.00013421666434645826 4.205021841480967e-6 -1.3158154697724344e-5 -9.285764744184922e-5 0.00016286724795441934 -2.8886202099089645e-5; -8.230357949255733e-5 0.00018240756011468585 0.00023388234089029248 4.331205073273736e-6 0.00014469657739059295 -2.893362276122673e-5 4.9720079314599286e-5 -3.66151677186764e-5 -8.603343181644333e-5 -1.3673651355614427e-5 9.111441007948714e-5 -0.00010850324784737778 -4.523627527273958e-5 0.00016326157614489313 -0.00010101278772677061 -6.962656212604372e-5 9.863074350537034e-5 -4.334358769451966e-5 5.3590997904709075e-5 4.947561441472033e-5 0.0001064004767070827 -0.00011895284720563096 -4.05554388603662e-6 -1.790613228845644e-5 -0.00017077343241876696 -0.00011536040861724642 -2.5196983787876095e-5 -5.1951078293907915e-5 1.705539156350538e-5 9.285489917633356e-5 -7.533384510648205e-5 -9.226753778389423e-5; 9.064140186716956e-5 -5.224596742585555e-5 -5.553745057944572e-5 4.686253242731672e-5 1.966465543852672e-5 -0.00027610310343255656 0.00016611053734856334 7.119403845439659e-6 7.485995166540706e-5 -7.495313610809937e-5 7.168967868972718e-5 2.7969822723084378e-5 -3.501476811755619e-5 -5.1269763106697156e-5 0.00010617742674263755 -0.00022141642392774346 -7.363357571114207e-5 4.971457684484346e-5 
-0.00010170510057753145 -0.00011749980757406633 -0.00010177473876785015 -0.00021228856047776905 0.00011559974819726474 0.00013458401673312342 -0.00015890337516715066 0.00015659175124789268 -0.000193615935758439 -5.5205437719573774e-5 5.56749867527189e-5 -4.0484468548315e-7 -1.528833371515625e-5 0.0001976548113005896; -3.448492556242466e-5 -0.00014828944462249427 -0.00020502104656668803 -2.3982637357110623e-5 -0.00012326067429865035 1.672336598960259e-5 -0.0001436419121445197 -0.00011470741504932913 -0.00011908449290683922 1.0165428976618501e-5 -4.874719947008197e-5 0.0001734781538844998 7.414980005807837e-5 -1.4330220109402765e-5 -0.00014800790871857163 5.811859169727756e-5 -0.00020270940207308164 7.730083540184175e-5 3.6367625726713525e-5 0.00020435573184033717 5.411115795243894e-5 6.474804045967327e-5 -8.871204830836695e-5 -6.251729122493659e-5 1.0231856650645832e-5 0.00011891869449857219 0.00015900048252762836 7.637099712257798e-5 1.2642621514835567e-5 0.00015081521938657106 -9.549606812938339e-8 -2.9366274828931683e-5; 0.00014368944941152282 -0.00017663998331344865 3.676298767159612e-5 -3.8302605010625874e-5 -2.9994088966575885e-5 -1.2542112547902733e-5 9.559500009986792e-6 9.52665997527195e-6 -1.4303408155041552e-5 0.0001795175161684737 1.2205783984147047e-5 4.3095661036846206e-5 -0.00011037403199524868 6.709071423456842e-5 0.0001010666641529549 -2.9555490603611858e-5 -1.6055800718553123e-5 -1.3416061695335664e-6 4.725586078775e-5 -4.1414808568607014e-5 -9.179997126294766e-6 0.00012376896885753536 0.00010112728015583398 -3.306197648571314e-5 -9.868508239678286e-5 -5.486233165656337e-5 5.24061837647468e-5 3.6413516151428976e-5 1.7692222551599117e-5 9.17538676882965e-5 3.772040183021497e-5 -3.250444168161145e-5; -0.00013234258836094376 -1.607633089664188e-5 8.191436176927452e-5 -0.0001955039103930474 -3.5438622515639728e-6 -9.165081314934726e-5 5.25269108467613e-5 -1.022780520799643e-5 6.9989139999204e-5 -0.00017549509971361796 4.416287742712852e-5 
8.635877133853432e-5 -2.301725976652226e-5 9.414789496728881e-5 3.731902074575111e-5 7.427644274837577e-6 9.206971523937362e-6 0.00011332874532597003 1.9848184135455463e-5 1.3856780467149498e-5 -2.26110523983554e-6 3.057290927896892e-5 -3.426371290773235e-5 -0.00014822056141641126 -0.0001104091777532015 -4.982415589696087e-5 -6.66665862380182e-5 -2.2595701696293586e-5 0.0001503226605993203 -5.9923610708966716e-5 7.380819556344045e-5 4.879405355238241e-5; -4.2960867333120996e-5 -4.0542517283125154e-5 -6.856353245175769e-5 2.9473440761553527e-5 -5.966521091259675e-5 0.00021194032053962957 0.00028199008268338274 4.4657593050828616e-5 2.142956039087527e-5 0.00012369567305349225 1.2835246550716224e-5 -0.00010693522305892184 1.8656756608737697e-5 3.3169143378376785e-5 1.8354285955770467e-5 -0.0001420684140358441 -8.200761764975311e-5 7.183526652433307e-5 4.333189579380714e-6 3.665956212707022e-5 8.797356606726846e-5 3.678074589473962e-7 4.6840402162971516e-5 -0.00010377459800302661 6.764196480425603e-5 -6.746135584044627e-5 2.180652229789732e-5 2.701973865250242e-5 -4.9194900541007755e-5 -0.00022031104997232172 7.067143344820065e-5 8.173610588336655e-5; -6.392078309645626e-5 -6.620668525202335e-5 0.0001224747936511388 -8.388000670589955e-6 -3.734495500894141e-5 -0.00021374071754608545 1.8072148785807977e-5 1.593111822232588e-5 1.5023899985765881e-5 -0.00013455127579718802 -0.0002356625084495059 0.00016591857613902313 -1.6301117531024486e-5 -7.315340578416515e-5 1.0990407398724776e-5 -7.478201975673558e-5 5.5443440295704775e-6 -3.243774360371616e-5 0.00013386693511065903 7.137114361194173e-6 6.0751193567659156e-5 -2.4466194382573896e-6 -4.478562251371071e-6 5.4032530650963634e-5 -0.00020119760690350222 -0.0001525255832085406 -5.5590946677608234e-5 0.00013092482888977765 -7.094326089926033e-5 7.789986168872625e-5 -8.733045680042178e-6 0.00017882958830241297; 0.00020319665571708482 -0.00013252756827011254 3.327976963749666e-5 -0.00010553357634659348 -1.6566793705231114e-5 
-7.52260041862116e-5 0.00017397347045165044 -0.0001227180931108529 -5.508747045848003e-5 7.315467308774961e-5 -1.3736863651387211e-5 0.0001888861276585286 -7.855661016389314e-5 -5.5262897434537315e-5 1.0923740488701064e-5 -5.0370776365404196e-5 0.00010006964811157472 -9.94372670782415e-7 -9.411166894448964e-5 6.212682427406497e-5 7.85706805447671e-6 4.622995919465606e-5 7.2153549682583185e-6 0.0002773211869554252 -0.00010643739580142787 2.656360055528604e-5 7.348397565341098e-5 -0.00010819882510428752 -4.662423299476192e-5 1.4236554931023887e-5 5.8104176501850426e-5 -5.7965599184052036e-5; -1.4726936070559892e-5 0.00014941603653724025 7.752477317708109e-5 1.573940385141921e-5 -6.011587761449526e-5 -7.491010405929264e-5 -1.9250572632792438e-5 0.000152453108556928 -0.00024142749539451185 -0.00018178271643691732 0.00021011307903692785 -6.345589140537606e-5 -0.00028006037105214945 5.6178719650225166e-6 -0.00013171916584824654 -3.4296916465882924e-5 -0.00012775538421109362 -5.884974093632648e-5 -1.0594957030739695e-5 9.707810890987054e-5 -6.536678342952215e-5 7.44084004953924e-5 -0.00013478669502646357 -0.00012571546763055815 0.0001287961618335424 -6.119823451727068e-5 4.3945330206860316e-5 -6.358337345873391e-5 2.0127308333878576e-5 1.5128475341864487e-5 0.0001336129912932843 -2.944433658839937e-5; 0.00010350252979765613 9.909866186423218e-5 8.710903331736351e-5 3.95810593697824e-5 6.099197254730672e-5 2.5885251504194624e-5 -8.557344684261704e-6 -0.00011705543600628633 -2.0851896791311087e-5 -5.8816135591077104e-5 -0.00016661644347847784 -4.536604772805023e-5 -4.626749522070605e-5 2.9617186570945513e-5 0.00022900587920664147 0.0002517417915594387 -4.9602153906795244e-5 7.00633127205564e-5 7.663715497686984e-5 -5.9092007162006146e-5 1.4514462251756956e-5 -5.362271347915471e-5 -4.577419516176992e-6 0.0001073621997572969 3.310313970065662e-5 -5.173171742383582e-6 1.6328834150741597e-5 8.342371713003312e-5 0.0002452603103090133 0.00022050201643356563 6.0464745744652604e-5 
-0.00011574812832190531; -0.00011702012482916289 0.00022354297476934548 8.813784740280542e-5 0.00017171723090226017 -0.00012719056024567627 0.00011504940272655519 -2.0780094867830772e-5 3.6491084817096624e-5 1.052202248186512e-5 8.152100423697818e-5 -0.00014199375794619697 0.0001637894493492526 -0.000160705032552384 7.980443121280788e-5 8.807176715575333e-5 -0.00015770725070384967 1.5990930856247587e-6 -0.00021022866343033887 0.00012787391102987586 -0.00012445820533090436 -3.698352150736462e-5 -3.3071541050475265e-5 5.526113824683898e-5 -3.689639555292584e-5 -0.00013486062737999208 2.459987087289093e-5 -0.0001249354644946543 5.459985556311185e-5 -0.00026866966430447647 4.136727117415296e-6 1.7269458172934278e-5 -8.625856929344469e-5; -0.00016867520997031514 5.0560173077269065e-6 -8.461078203246596e-5 -0.0001321794140311949 0.0001263168208607394 0.0002839149650026984 3.461056741806061e-5 8.829639893526142e-5 -6.505935987653805e-5 -0.0001608621266131363 -1.7266910326656968e-5 0.0001291198253239445 2.9223190104023148e-5 0.00014630153356855094 7.527718600620853e-5 0.0001598813663077475 8.485687914806288e-5 -0.00017669212200409595 9.142130631523926e-5 -4.541973492443545e-5 5.774798592696339e-7 6.319703727065925e-5 -4.5318693701046715e-5 -3.283052745238746e-5 7.53809411617862e-5 -1.6105705690203832e-5 1.2325165080330983e-5 1.8059410031203565e-5 7.162369485747222e-5 0.00011674011813746011 -0.000126574933055888 -0.00030242072778484673; -0.0001635076617542278 -0.00021130478064991434 8.443825997416045e-5 0.00020810325865637625 1.1664843141954065e-5 -7.057583509707708e-5 2.8013179572337473e-5 -4.4058396569418665e-5 0.00011749652973845718 7.407160653270857e-5 4.69545999938054e-5 -4.777589431727192e-5 9.950379232646985e-5 4.4048946320583204e-5 -0.00012847430419172975 6.855902090206841e-5 -4.541732353503837e-5 -9.812310353698561e-6 8.370648781306997e-5 -7.24997947411237e-5 0.00013563594317998596 2.3446155565363602e-5 9.556320644220375e-5 5.6949633900113966e-5 
-0.00015990564072966475 -3.0128863117386424e-5 0.00012199514519194247 0.00011901738677876181 -0.00014171397136055145 -3.342982684374717e-5 -7.100201703837043e-5 -0.00012150392952138508; 8.432722042582487e-5 0.00026151811416296416 9.21455206446192e-5 -0.0001368876843150864 -4.848194348939945e-5 8.873824062560041e-5 -5.4885706598514285e-5 -2.5097368601390666e-5 5.95339073663421e-5 1.7845773432989426e-5 7.792007035842286e-5 6.339851042524246e-5 9.56178186208756e-5 -0.00012851906704215179 3.675903589849657e-5 0.00016200439076440517 -0.00011043316099292703 7.566317744588205e-5 0.000268787145065485 -0.00018029409748467304 -0.00013365958930425674 0.00018316690638697078 0.00018161720017471827 2.4268925150159635e-5 -0.00011179339399292934 -9.626667506724371e-5 -3.965884080004698e-5 3.4303713069886476e-5 -4.593433969039549e-5 -4.0442057619450325e-5 -8.425946001597971e-6 2.8092970202328888e-5; 0.00012333886098815698 -3.521646084379533e-5 7.4746142792078214e-6 -4.237554475082154e-5 0.0002497341918689264 -0.00016850860282545044 4.5045016861994805e-5 0.00017318330338922578 -5.709385385386705e-5 -5.79566842014831e-5 -5.330945367575588e-5 -9.342900028684252e-5 -9.351421630242084e-5 -6.899314047302238e-5 -0.0001393822169001127 3.3829997676665455e-5 -5.0701786846617716e-5 -4.87283179489832e-5 6.18570991838374e-5 1.5716901373821614e-5 4.8660527511963986e-5 0.00019023493206362182 0.00017565862782532916 -9.050747775006408e-5 -3.266293730913439e-6 8.947533916682845e-5 -8.11888168635696e-5 0.00012850610056660053 -2.6360663638920136e-5 -3.827428389456371e-5 0.0001178638556785961 -3.9272519813378796e-5; 7.57176823316585e-5 -0.00015946767220893077 -0.00031570740961623495 0.00011017504598558766 0.00015334803915894395 7.624711133582965e-9 3.540709817744672e-5 -8.084492569478008e-5 0.00014456626390884524 -4.853876075511972e-5 -0.00013660396493659786 -9.431498118002182e-6 0.0001570806490707662 6.551944548110344e-5 4.607317182663719e-5 0.00014523039876795238 -7.18133886476545e-5 
9.874528054925413e-5 -0.00012234142909446155 -0.00021378319056392903 0.00015101845215386206 0.00013681022405936276 4.27877750115912e-5 -0.00012591888648611697 0.00012617437473810145 -7.575123877185521e-5 -4.270463964357168e-5 -5.034370756046965e-5 -4.756773691778683e-5 0.00011625204220027481 0.0001454668528384981 -0.0001436762539452429; 1.671925422938463e-5 -9.13469608990162e-6 -0.00015166355985188267 0.00016329743851470417 -8.530638257711686e-5 -6.322831865604377e-5 8.342289259685777e-5 0.0001043510132448702 -1.8902490016486155e-5 -0.00015656717693404747 5.02650549905843e-5 2.302765136205709e-5 -6.993362291454756e-5 6.746431648294065e-5 -7.578832407220607e-5 0.00020811527104611116 -9.138248384902031e-5 -1.0786075624353566e-5 7.713110283743749e-5 2.7110102048892194e-5 7.718980526346613e-5 7.385921383770777e-5 3.062708531883929e-5 -5.860241396207864e-5 -6.245234505244916e-5 -2.1997075846652566e-6 0.00016838487539003235 5.1463939992530446e-5 0.00011029384080115419 -6.975035609416196e-5 -0.0001119735384846052 -0.00010929295928421949; 6.321972717246923e-5 5.031551075349621e-5 -8.26914973631958e-5 -3.463375002778981e-5 -2.769869285570471e-5 6.0752523777025685e-6 -6.61480725002933e-5 -0.0001498126938432052 0.00010059277076921516 4.319875472988947e-5 2.3192305098703914e-5 2.522794039539549e-6 -2.53538117877204e-5 0.00022769416785282142 5.152906772396617e-5 6.357528139320258e-5 -3.836251459601497e-5 -0.00020531748202214968 7.47278473963475e-5 -2.9436548055895497e-5 -7.824152698092235e-6 0.0001326848879497973 -0.00014984157939492782 0.00012942396381818668 -0.00011811586488787957 7.803158326256166e-5 -6.322546584582819e-5 7.315052164900196e-5 6.3718908958901194e-6 -1.8846293523022865e-5 3.203010522243622e-5 -1.7267117863451447e-5; 1.7120980556806072e-5 4.541352718764832e-5 -2.9633751909257265e-5 9.491647615608805e-5 0.0001762657568189179 0.00013812908691865448 -1.6815313573586704e-5 6.431859328686586e-5 -0.0001220399270085704 5.70048916425898e-5 1.4880674099114913e-5 
-4.558334585349968e-5 -2.935621960101374e-5 -0.00010720758182183899 0.00012627349606295183 9.329782296948963e-5 -5.0095178084003145e-5 -8.307905845764546e-5 2.8973253575829238e-5 -4.207185593155981e-5 5.163071655362496e-5 5.502478343363928e-5 -3.489473492560595e-5 -0.0002215594056896074 0.00012598071152855557 3.805686465160922e-5 0.00014479355850302575 0.0001305879934090098 -0.0001651441623825025 0.00012055664519824673 4.9840394423100074e-5 -3.338752030582688e-5], bias = [-1.6642940413152784e-11, -2.119324698440255e-10, 1.7084832893901097e-9, -1.982179549581815e-9, 1.8558681103132824e-9, -4.4110888129605043e-10, -2.840696572816941e-9, 1.4543271090370464e-9, -8.053217775747997e-10, -1.5819143608273394e-9, 4.5973834092913027e-10, 5.509255971632458e-10, -1.1342219658988697e-9, 1.2190383813579201e-12, -1.699562868152823e-9, -7.572806483258503e-10, 1.4955882814309778e-9, -7.337374271757467e-10, 1.177639110579588e-9, -1.6045952241117962e-9, 8.564585499141288e-10, -2.388743362263929e-9, 3.989968507572442e-9, -1.5242526104569571e-9, 8.415737130410569e-10, 2.4799261359976257e-10, 2.3859649740911802e-9, 9.74808362675086e-10, 3.7848795656406417e-10, 6.625059120247138e-10, 4.5213797221124114e-10, 2.2057075899237935e-9]), layer_4 = (weight = [-0.0007620667096548071 -0.0007181191823554438 -0.0006268600848791088 -0.0007369362131989231 -0.00044647129589662067 -0.0006868364097483228 -0.0007586670642600782 -0.0005231985865447133 -0.0005088717219663348 -0.0007900449118211369 -0.0007118599954782756 -0.0007826257349673678 -0.0007283562178949217 -0.0007246874742254523 -0.0007465963835338845 -0.0006686848792011674 -0.000661998632425312 -0.0007838723326190424 -0.0007161606157633745 -0.0007806146860567857 -0.0006613555440445958 -0.000598356544568073 -0.000630930964956839 -0.0008556952832803781 -0.0008173866404054942 -0.0006520687772608647 -0.000821077367225921 -0.0006753975590575852 -0.0007111184316123341 -0.0007775343508005341 -0.0006223193293147062 -0.0005346196187088148; 
0.0002542733602888137 0.00022095421977443648 0.00015176267856910832 0.0002070354192568218 0.00039043207843661155 0.00023704283507621916 0.0003260220197130887 -4.867937512075738e-6 0.00015830243011985223 0.00019930055711584672 0.00022275390708278422 -3.6707674171703936e-5 0.00037313589182164493 0.0003123591889324337 0.00021026922844485018 0.00024503748220120033 0.0001926691962408972 0.00032616713267415794 0.0002930843065950639 0.00030746426547614895 0.00025422424499823914 0.0003171425522991815 0.0003476574601136175 0.00028216482032348525 0.00016480726604123743 0.00016702934563801267 0.00014673949108664358 0.0003962995753810359 0.0003011479338111193 0.00013693847364662823 0.0002142049173973351 0.00015782203034590342], bias = [-0.0006698424480551261, 0.00022267838382859916]))

Visualizing the Results

Let us now plot the loss over time

julia
# Plot the loss values recorded by `callback`, indexed by the call in which each was seen.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results by comparing the data with the untrained and trained neural network waveforms.

julia
# Re-simulate with the optimized parameters `res.u`. This rebinds the globals `prob_nn` and
# `soln_nn`; the untrained `waveform_nn` computed earlier is reused in the plot below.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained =
    first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Compare the data, the untrained prediction, and the trained prediction on one axis.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Additionally report GPU toolkit versions, but only for backends that are loaded in this
# session and that MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.