Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie
Precompiling OrdinaryDiffEqLowOrderRK...
   4010.8 ms  ✓ OrdinaryDiffEqLowOrderRK
  1 dependency successfully precompiled in 4 seconds. 98 already precompiled.

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$ (e.g. Theoretical Mechanics of Particles and Continua, Section 4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory into the trajectories of the two bodies.

`path` is scaled by `m₂ / (m₁ + m₂)` to give the first body's positions and by
`-m₁ / (m₁ + m₂)` to give the second body's positions. Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in orbital variables into Cartesian orbit coordinates.

`soln` is indexed as `soln[i, :]`, with one row per state variable:

  - 2 rows: `χ` and `ϕ`; `p` and `e` are then read from `model_params = (p, M, e)`.
  - 4 rows: `χ`, `ϕ`, `p` and `e`; `model_params` is not used.

Returns a matrix with two rows, `x = r cos(ϕ)` and `y = r sin(ϕ)`, where
`r = p / (1 + e cos(χ))`.

Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if `model_params`
is missing or does not have length 3 in the 2-row case.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        @assert(
            !isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        # M is part of the parameter triple but does not enter the change of variables.
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    # Stack the coordinates as rows so that orbit[1, :] is x and orbit[2, :] is y.
    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled values `v` with spacing `dt`.

Interior points use the central difference `(v[i + 1] - v[i - 1]) / 2`; the first and last
points use three-point one-sided stencils. `v` needs at least three entries. Returns a
vector with the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    left = -1.5 * v[1] + 2 * v[2] - 0.5 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) ./ 2
    right = 1.5 * v[end] - 2 * v[end - 1] + 0.5 * v[end - 2]
    return vcat(left, interior, right) ./ dt
end
d_dt (generic function with 1 method)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled values `v` with spacing `dt`.

Interior points use the central stencil `v[i - 1] - 2v[i] + v[i + 1]`; the first and last
points use four-point one-sided stencils. `v` needs at least four entries. Returns a
vector with the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    first_pt = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    middle = v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    last_pt = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(first_pt, middle, last_pt) ./ dt^2
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along an orbit.

`orbit[1, :]` and `orbit[2, :]` hold the `x` and `y` coordinates. `component` is an index
pair: `(1, 1)` gives `x² - (x² + y²) / 3`, `(2, 2)` gives `y² - (x² + y²) / 3`, and any
other pair gives the off-diagonal term `x * y`. The result is scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    row, col = component[1], component[2]

    moment = if row == 1 && col == 1
        xs .^ 2 .- (xs .^ 2 .+ ys .^ 2) ./ 3
    elseif row == 2 && col == 2
        ys .^ 2 .- (xs .^ 2 .+ ys .^ 2) ./ 3
    else
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the moment-tensor `component` of `orbit`,
using the finite-difference spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    second_derivative = d2_dt2(orbit2tensor(orbit, component, mass), dt)
    return 2 * second_derivative
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the quadrupole terms of `orbit` for the index pairs `(1, 1)`, `(1, 2)` and
`(2, 2)`, returned in that order as the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end

"""
    h_22_strain_one_body(dt, orbit)

Compute the (2,2)-mode strain of a single orbiting body from its `orbit`.

`dt` is the sample spacing used for the finite differences; its type `T` also sets the
precision of the constants. Returns the tuple `(h₊, hₓ)` scaled by `sqrt(π / 5)`, with
the cross polarization negated.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalization of the (2,2) mode is sqrt(π / 5), not π / 5.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole terms of two bodies component by component.

Each body's terms come from `h_22_quadrupole` with its own orbit and mass. Returns the
tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Compute the (2,2)-mode strain from the orbits of two bodies with masses `mass1` and
`mass2`.

`dt` is the sample spacing used for the finite differences; its type `T` also sets the
precision of the constants. Returns the tuple `(h₊, hₓ)` scaled by `sqrt(π / 5)`, with
the cross polarization negated.

Throws an `AssertionError` if the masses do not sum to one (within `1.0e-12`).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalization of the (2,2) mode is sqrt(π / 5), not π / 5.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Turn an ODE solution into a (2,2)-mode waveform.

`soln` is converted to a Cartesian orbit with `soln2orbit(soln, model_params)`. For
`mass_ratio == 0` the one-body strain is used. For `0 < mass_ratio <= 1` the orbit is
split between two bodies with masses `m₁ = mass_ratio / (1 + mass_ratio)` and
`m₂ = 1 / (1 + mass_ratio)`, and the two-body strain is used.

Returns the tuple produced by the selected strain function. Throws an `AssertionError`
if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Normalize so that m₁ + m₂ == 1, as required by the two-body strain.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \qquad u[2] = \phi$$

where p, M, and e are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit equations.

`u` holds the state `[χ, ϕ]`; only `χ` enters the right-hand side. The second argument is
destructured into the constants `p`, `M` and `e`. `t` is not used. Returns the
two-element vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    denom = sqrt((p - 2)^2 - 4 * e^2)
    numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2

    dχ = numer * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * denom)
    dϕ = numer / (M * (p^(3 / 2)) * denom)

    return [dχ, dϕ]
end

# Problem setup shared by the reference simulation and the neural ODE.
mass_ratio = 0.0         # 0 selects the one-body (test particle) branch of `compute_waveform`
u0 = Float64[π, 0.0]     # initial state [χ, ϕ]
datasize = 250           # number of saved time points
tspan = (0.0f0, 6.0f4)   # integration interval for the GW waveform (Float32 literals)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved point
dt_data = tsteps[2] - tsteps[1]  # spacing of the saved samples, used for the finite differences
dt = 100.0               # fixed step passed to the ODE solver (used with `adaptive=false`)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4, saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a 2-tuple; only its first element is kept as the reference data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers at the saved time points.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # The line and scatter handles are grouped under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    # Last expression of the block, so the figure is the block's value.
    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the orbital variable χ, i.e. `first(u)` in the code below) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally not recommended, but in case you do use them, make sure to mark them as const.

We will deviate from the standard neural network initialization and use the initializers from WeightInitializers.jl:

julia
# Network used inside `ODE_model`: 1 input -> 32 -> 32 -> 2 outputs, with `cos` activations
# on the two hidden layers. Every weight matrix is drawn from a truncated normal with
# std 1e-4 and every bias starts at zero.
const nn = Chain(
    # Wraps `fast_activation(cos, x)`; presumably an elementwise cos of the raw input.
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# Initialise the parameters `ps` and layer states `st` using the default RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.0001088696; 3.9620598f-5; -3.59188f-5; -4.802416f-5; 0.00013416172; 7.012122f-5; -3.3375854f-5; 3.2921515f-5; -1.3689365f-6; -7.938637f-5; 8.268213f-5; 0.00014754795; 8.7381144f-5; -7.796218f-5; -4.814505f-5; -3.0941083f-6; -7.920253f-6; 9.404849f-5; 5.2537984f-5; 6.66003f-5; 2.6913744f-5; -5.514224f-5; -5.242156f-5; 2.5640087f-5; -2.6342583f-5; -0.000114855015; 0.00014001878; 5.742519f-5; 0.00015158491; 0.00014264573; 5.717896f-5; -1.0874881f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[1.1879883f-5 -3.6251456f-6 -8.995461f-5 -0.000117781514 0.0001012826 8.648709f-5 0.00010011789 6.397603f-5 -4.5012303f-6 8.317105f-5 -0.00011476996 -8.538848f-5 4.199721f-5 -2.9971188f-5 -2.8108037f-5 1.4857335f-5 3.1714488f-5 -9.85532f-5 1.1640061f-5 -0.00011619366 0.00012655069 8.140016f-5 9.6310774f-5 -5.3496806f-6 1.1254644f-5 4.6110123f-5 -3.0392905f-5 0.00010739312 -1.9103963f-5 -0.00010599501 -4.3991713f-5 -0.00014564446; -4.738302f-5 6.5194006f-5 2.750009f-5 -0.00014160293 0.00012827836 3.2010215f-5 -0.00027159133 -0.00018891702 -9.632272f-5 -2.5319287f-5 2.9484045f-5 3.7653863f-5 1.362437f-5 -6.49331f-5 0.000125517 -4.9449747f-5 -5.527678f-6 6.205478f-5 -1.29448135f-5 -0.00014234654 -0.000201472 0.00013574574 -0.00022685627 -7.8283065f-6 7.3325835f-5 2.9817174f-6 -3.5905072f-5 -9.281952f-5 -3.2922606f-5 -5.1260966f-5 0.00014684764 -1.7997992f-5; -0.00027907526 -8.472346f-5 -8.8902074f-5 7.574673f-5 -5.0059284f-6 -4.3165783f-6 -3.8245103f-5 -1.4906721f-5 -4.964581f-6 3.5505407f-5 -0.00010433692 1.8702876f-5 -4.0604427f-5 0.00011851205 4.8424612f-5 1.9510737f-5 1.8543145f-5 -0.00012883895 9.118146f-5 -1.3550919f-5 1.3279089f-5 9.972147f-5 -1.2820722f-5 -5.354059f-5 -0.00018392016 0.00011963916 9.600828f-5 -4.673878f-5 -4.945662f-5 -4.8427406f-5 
-0.000100708094 4.499011f-5; 0.00022776786 -1.3625908f-5 0.0001661617 -1.7991037f-5 -4.51149f-5 3.1694173f-5 5.6691002f-5 -0.00013269405 -0.00010962746 -6.213112f-5 1.4915844f-5 1.8726927f-5 6.02605f-5 3.5608741f-6 -0.00015176724 8.9578236f-5 1.7809447f-5 -0.00021597873 -4.1885116f-5 1.9228883f-5 -0.00010008241 1.8511077f-5 -1.523914f-5 2.5441768f-5 9.997402f-5 -0.000110374516 -5.468573f-5 -0.00013023955 0.00015833358 3.1011714f-5 1.260112f-5 1.6789883f-5; 3.3600896f-5 -4.765881f-5 0.000116210365 -4.325252f-5 -0.0001015085 0.00015067813 -2.2501055f-5 2.2626684f-5 -6.539062f-6 0.00015471982 -6.610092f-5 -8.155789f-5 -6.8679736f-5 7.036473f-5 2.1556589f-5 -5.165947f-5 -6.034798f-5 -2.971546f-5 -0.00023249946 4.18499f-5 2.294829f-5 -0.00018899824 -0.00012215787 -0.00010271619 3.7740032f-5 -0.00015036731 2.1419442f-5 3.162925f-5 0.00010905058 1.4783846f-5 -3.9686838f-5 -4.687296f-5; 1.8505349f-5 -4.501319f-5 0.0001639931 0.00011258385 -7.41432f-5 4.5997785f-5 5.8491885f-5 -7.8155295f-5 -0.00013067845 0.00017496076 0.00014810576 -1.7553033f-5 -0.00010706503 7.6107644f-5 3.000617f-5 7.186646f-5 3.3046115f-5 -0.00011352548 -3.30775f-5 2.3180512f-5 5.192161f-5 -0.00034195362 -0.00011846709 0.00013143467 -3.5622914f-5 -8.0004604f-5 8.054909f-6 -0.00010941066 -6.417433f-5 -6.431534f-5 4.4239387f-5 -6.528463f-5; 0.00012926503 -3.8509024f-5 7.1317234f-5 -8.082968f-5 8.280981f-5 0.00010877891 -0.00015692435 -4.850887f-5 0.00013075134 -1.9384253f-5 0.00018767004 -7.466641f-5 8.251669f-5 -9.3261806f-5 -8.419801f-6 -4.2936877f-5 9.739299f-6 -5.2664487f-5 -0.00015377345 0.00011626083 0.00018066206 -3.6804183f-6 2.5891535f-5 1.4513614f-5 -4.708554f-5 -6.162503f-5 -3.1534222f-5 -0.00018509263 7.931397f-5 6.824104f-5 -1.2750647f-5 0.00019818699; 2.1011706f-6 -9.056757f-5 4.215159f-5 -0.00015330096 -3.3173463f-5 -0.0001335461 1.1771059f-5 -5.8580466f-5 0.00014061377 1.0565735f-5 -1.48246445f-5 -6.536669f-5 6.072576f-5 -4.079722f-5 0.00024412056 -5.0488903f-5 -0.00020464994 
-0.00012276511 -0.00013126744 -5.9300204f-5 0.00020215065 -7.3587755f-5 -2.0324263f-5 1.9397574f-5 0.00015824406 -9.354004f-5 -0.000100257064 0.00013379552 0.00011017077 1.8340119f-5 7.450709f-5 2.2442402f-5; 8.210317f-5 1.1130114f-5 5.4343123f-5 0.00015315072 0.000118471864 -8.464348f-6 2.5339823f-5 -3.0215877f-5 -0.00016692306 7.748309f-5 2.1852236f-7 1.8248382f-5 3.3360777f-6 -0.00010303216 -6.012163f-5 8.789765f-5 -9.790373f-5 -2.6547901f-5 4.3374574f-5 -7.336379f-5 1.6326372f-5 0.00013014994 -7.3782976f-5 -2.0194286f-5 9.824739f-5 -6.6974615f-5 -0.000140706 -5.0829156f-5 -1.0559575f-5 1.4241192f-5 2.6452546f-5 4.0778978f-5; 0.00012021398 4.262071f-5 -9.430434f-5 -7.043528f-5 -0.00013792499 -9.593399f-7 3.342546f-5 5.7612844f-5 2.8982857f-5 2.8477978f-5 0.00013105257 -6.1911414f-5 0.00012288333 -1.8310337f-5 -0.00029956776 0.00010656271 -7.82791f-5 5.3214368f-5 0.00033449737 7.526535f-5 -4.0858653f-5 2.9901283f-5 -4.1676903f-5 0.00016656231 -9.688765f-5 -6.588266f-5 9.471676f-6 0.0001258408 6.48421f-5 2.0054324f-5 -0.00017835597 4.4963068f-5; -6.585446f-5 -4.2774864f-5 -0.000119534605 -6.444842f-5 9.716938f-5 5.3752665f-6 2.0398593f-5 5.180791f-5 -2.6750966f-5 -0.00021422586 0.00010142724 5.8641395f-5 -7.558403f-5 -1.1479667f-5 6.19799f-5 0.00013103939 4.649747f-5 -3.735259f-5 2.2856571f-5 -7.802559f-5 -4.4759592f-5 -3.1698944f-6 5.8005488f-5 -1.014836f-5 -0.00024012786 -0.00017123454 -4.3101587f-5 4.2799922f-5 3.5502449f-6 1.5278018f-5 0.00022288684 2.4683804f-5; -1.4507483f-5 2.6182614f-5 0.0002018318 3.0826646f-5 -1.9663568f-5 4.5102144f-5 8.393382f-5 -7.6057535f-5 -5.7208752f-5 -6.22207f-5 -6.918613f-5 -9.27489f-5 -0.0002276319 0.00012432161 -2.9692817f-5 -0.000103274644 0.00011539985 -3.556238f-5 0.00011473314 -6.695968f-5 4.178285f-5 2.1162223f-5 -3.2705946f-5 1.2507757f-6 -4.3896584f-5 0.00015315085 8.221999f-5 0.00012202042 -0.00012239325 -7.478593f-5 -0.00012584243 -6.240874f-5; -5.1290852f-5 1.15828725f-5 -4.683849f-5 -0.00011459783 0.00012416135 
3.2033902f-5 0.0001815217 0.0001617458 -0.000138307 0.00012789569 -0.0001206302 -3.4549396f-6 4.1078943f-5 7.284979f-6 0.0001604562 7.00043f-5 -9.666728f-5 7.5992466f-5 -0.0001917679 2.7427262f-5 0.00025675748 -2.7401247f-5 -0.00016025714 0.00010458802 0.00010669382 1.6276272f-5 -5.2235275f-5 -9.747994f-5 -2.6650123f-5 -6.658058f-5 -3.0860156f-5 9.372207f-5; 4.2069816f-5 -2.5027619f-5 -0.0001683741 -3.1561936f-5 0.00022355177 7.2568226f-5 0.0001728973 1.5627504f-5 -0.00011028191 9.948142f-5 9.749071f-6 -0.00010099848 1.3349814f-5 9.2976f-5 -1.670246f-5 9.4089075f-5 -0.00017422937 6.610545f-5 0.00013430194 -6.094712f-5 4.7263187f-5 -5.1304696f-6 -4.7569147f-6 0.00013098776 -8.225028f-5 1.959151f-5 4.61363f-5 1.5562457f-5 4.3646163f-5 -6.557024f-5 1.4129906f-5 -0.000120864985; 0.0002198824 -2.5154897f-5 -1.3064976f-5 3.5299618f-5 -6.658116f-5 8.028663f-6 1.4735194f-5 0.00019335486 -3.798115f-5 -5.4544642f-5 0.00014744069 0.00015593904 -0.00020126865 -4.1076237f-5 -0.00018958953 -2.6496788f-5 -2.4175366f-5 -1.2760348f-5 0.0003004939 3.1621066f-6 0.00010998854 -1.8287776f-5 -9.093854f-5 -1.8535411f-5 5.8811645f-5 1.7710563f-5 -7.779452f-6 0.00014085996 -5.091018f-5 5.1641155f-6 9.4383235f-5 0.0001704426; -4.567728f-5 4.607303f-5 -0.00019458045 4.3635848f-7 0.00016387427 3.3883178f-5 8.293043f-5 -0.00010058334 -0.00012228718 4.409956f-5 4.430266f-6 7.494708f-5 -4.2532915f-6 -0.00019235846 -7.9825455f-5 2.3075165f-5 -0.00013568757 0.00013724182 4.845768f-5 5.0094317f-5 -1.7215252f-5 9.0768626f-5 0.00015663153 0.00013575719 1.7524811f-5 -9.262917f-5 5.6308152f-5 3.089894f-5 8.6092376f-5 -1.2327305f-5 0.00012450472 -0.000115440474; -0.0001786856 -0.00018292619 0.00010924138 1.4558292f-5 -0.00020078674 9.4823285f-5 -0.00011334084 9.091475f-6 4.6557856f-5 2.8019298f-5 -9.744894f-5 -0.00025084813 -5.0959716f-5 2.6932783f-5 -0.00013050699 0.00013878594 -0.00016294289 -5.8494246f-5 2.6798698f-6 0.00012132949 -2.893342f-5 -0.00018865682 -7.824875f-6 0.00017781815 0.00012233829 
0.00015868172 2.0345937f-5 -2.5072217f-5 -7.490225f-5 -0.0001687315 0.00010194074 0.00019643347; 0.00011492113 0.00010545144 3.9055252f-5 2.2089007f-5 -9.266498f-5 1.4424777f-5 2.2196708f-5 -5.9628976f-5 -0.00013968266 0.00014999385 -5.3652773f-5 6.912264f-5 0.00018028979 2.6152122f-5 6.672236f-5 -2.2048895f-5 -2.7031867f-5 -1.9723513f-5 -7.600543f-5 -8.0058475f-5 -2.2153745f-6 -6.284297f-5 -0.000118416094 0.00010537717 5.9765134f-5 0.00024229496 -9.650996f-5 0.00015040363 4.8688922f-5 -7.018384f-6 -0.00022088565 -2.6830296f-5; 6.0558683f-5 -0.00012597935 -8.867854f-5 -0.00021454903 1.474631f-5 1.0760816f-5 5.6299865f-5 9.28831f-6 9.636989f-5 0.00016247723 -0.0001449404 9.664049f-6 5.611132f-6 -0.00018473259 -0.00025712766 -4.0163195f-5 5.82073f-5 -2.0722333f-5 2.2312905f-5 7.552415f-5 3.6169065f-5 0.00012707504 -4.0311672f-5 -6.269552f-5 -0.000107482076 -0.00010238329 1.9910729f-5 -0.00019486011 0.00020167197 -0.00014451297 0.00020487118 -0.00013707425; 5.647629f-6 2.4735024f-5 -1.2270145f-5 -0.00014921043 -0.00014422476 -0.00012657605 2.3127523f-6 2.8369783f-5 -1.8171399f-5 -8.898241f-5 -5.175763f-5 -0.00011616561 9.810193f-5 -8.373661f-5 -0.00023408937 0.00018484877 -4.3595228f-5 5.823242f-5 -3.825468f-5 -2.2714306f-5 0.00015920529 -2.862546f-5 4.232121f-5 0.0001816218 -7.727001f-5 -1.949834f-5 5.114173f-5 5.8165155f-5 -0.00017561474 -0.000119985656 -5.3733656f-6 -7.0198066f-6; -1.0586289f-5 -0.00014491043 -9.8481185f-5 2.0045754f-5 -1.660785f-5 2.4566545f-5 -0.00010726694 -9.392841f-5 1.244489f-5 -0.00016569927 -0.00010259325 1.615403f-5 -0.00010461549 -0.00017926966 -0.0001294352 -5.9232887f-5 8.775403f-5 3.5669116f-5 7.9359044f-5 -4.0264615f-5 3.7531438f-5 8.1533784f-5 -0.0001200305 -0.00013829158 2.6358812f-5 9.149837f-5 -1.5047314f-6 8.21175f-5 -0.00010852697 -0.00021549274 -6.271706f-5 -2.114907f-5; -2.6559244f-5 0.00011435455 0.00010723597 4.603613f-5 -8.6000044f-5 -2.8681143f-5 9.7829885f-5 8.9098445f-5 8.9004956f-5 -8.9958405f-5 0.00018306033 
0.00026260197 4.160223f-5 -6.70173f-5 -0.00017707142 -0.00013291242 2.5718951f-5 6.287228f-5 6.159456f-5 9.575527f-5 5.212029f-5 0.000118335294 9.269096f-5 -0.00012114965 -9.196527f-6 -0.00013915493 -4.1946907f-5 0.00036099556 3.2669795f-5 -4.4652224f-5 1.4608259f-5 4.9277358f-5; 0.00014270758 -0.00017709014 -0.0001262216 0.00018269989 -9.0951675f-5 5.9323025f-5 -0.00019651974 9.404158f-6 -1.8269755f-5 0.000105707935 -7.323874f-5 -0.00013007561 7.960047f-5 1.4624493f-5 -0.00019440478 1.9008511f-6 -2.6862861f-5 0.00012108779 0.0002272329 1.1853146f-5 0.000100623394 9.2125365f-6 -0.00020513043 -3.7115453f-5 0.00013770282 0.0001219753 2.8102788f-5 6.419057f-5 5.4472497f-5 4.290732f-5 5.279092f-5 -2.1062708f-5; 2.9200166f-5 5.2191426f-5 -2.4992996f-5 -9.5035866f-5 7.3438823f-6 0.00012873414 -0.00021149739 -0.00011748448 1.9758077f-6 3.732704f-5 -0.0001022762 0.00010412815 -6.0697013f-5 -8.911499f-5 4.0786585f-5 2.5250023f-5 5.503898f-5 -5.3229796f-5 -3.887135f-5 -0.00016950864 -2.068698f-6 -5.070305f-5 -8.536442f-6 -0.00013059603 0.00013167813 2.555345f-5 2.1136477f-5 2.257816f-6 0.00013910599 6.268321f-5 7.044266f-6 -0.000109633715; 3.8871527f-5 2.109028f-5 0.00017585655 -0.00010564411 2.2741855f-5 -8.597654f-6 3.1963232f-6 7.256244f-5 -0.00021954605 0.00013220879 -6.041582f-5 4.8494563f-5 0.00013686872 0.000113527116 6.958946f-5 -2.2055496f-5 0.00014214525 6.8999856f-5 6.308844f-5 0.00017192798 -5.130963f-5 -2.136369f-5 -9.343573f-5 6.931803f-5 0.0001149582 4.2167765f-5 -9.6591306f-5 -7.1534465f-5 0.00011217571 6.9733505f-5 0.00014824467 4.858496f-5; 3.60759f-5 0.00014882472 -0.00012182419 -0.000113395705 6.378534f-6 2.2445234f-5 0.0001392355 5.1584597f-5 -1.3277809f-5 -1.1359299f-5 0.00017081149 4.859794f-5 0.00012251313 4.6043737f-5 -6.95165f-7 -1.5435275f-5 -0.00022623788 0.000166514 -5.1400217f-5 0.000113654394 7.320088f-5 -9.9636585f-5 2.0414402f-6 1.8846531f-5 -0.00010503664 8.13235f-5 -0.00013421314 -0.0002233903 -1.8435607f-5 8.780702f-5 -5.4013548f-5 
0.00011228046; 2.006949f-5 -3.3089556f-5 0.00010540472 2.2008783f-6 0.000203296 -0.00011609201 -1.5843907f-5 -0.000174467 -7.674052f-5 4.3832042f-5 -0.00013136269 7.927016f-5 -6.0624657f-6 -0.00010338297 8.798077f-5 -0.00014990171 -1.6125317f-5 1.3553393f-5 0.00013767493 -1.6357064f-5 -0.00020217456 -3.6361645f-5 0.000106441505 -0.00014636082 -1.2908263f-5 -4.7991754f-5 1.3600739f-5 -0.00010369663 9.128798f-5 -9.469381f-5 -3.3520585f-6 -0.00025586342; -0.00014261958 1.7641959f-6 -0.00017414724 -1.8251483f-5 1.7750295f-6 0.0001827489 0.00020135977 -9.970825f-5 -9.122727f-6 -0.00010858355 -0.00010203087 5.4493754f-5 9.06313f-5 9.046483f-6 1.1927491f-5 9.1942566f-5 -3.0314866f-5 1.7198086f-5 -5.780126f-5 3.585185f-5 -8.841037f-5 -5.9961716f-5 -8.773105f-5 -0.00018465875 -2.1057118f-5 0.00012856603 2.3083898f-5 -8.401503f-5 0.00014212032 -4.3474747f-6 5.8335914f-5 0.00013789372; 8.132269f-5 -4.2709424f-5 2.8377757f-5 -0.00014452817 0.00017748935 -5.038588f-5 -0.00010458344 -9.498932f-5 -1.145067f-5 0.00015961404 6.330031f-5 0.00017904086 -3.528378f-5 9.021159f-6 4.5495886f-5 -4.6467416f-5 5.8766258f-5 -7.016516f-5 0.00019826124 -3.287245f-5 -3.9746195f-5 -2.4441482f-5 2.2555701f-5 0.0001110859 5.2772073f-5 7.9329584f-5 2.0032508f-5 -6.6457775f-5 -7.002792f-6 0.00011820125 9.6961034f-5 1.890566f-5; -8.486009f-5 -7.737132f-5 7.471828f-8 0.000121427125 1.6006281f-5 6.998195f-5 0.00023448917 8.796543f-5 -8.548843f-5 -7.469259f-5 5.5868215f-5 -0.00010647489 -0.00018843728 -5.5975834f-5 -7.476657f-5 0.00013939914 -9.134879f-5 -6.4718224f-5 0.000105731364 3.610479f-6 -1.5299345f-5 -0.00018568068 -7.318354f-5 0.00027632166 0.00012021807 5.1703068f-5 -0.00015020651 8.3905514f-5 0.00010316696 -3.606991f-5 -0.00012002322 -9.556076f-5; -3.998852f-5 5.2774692f-5 -0.00016444518 6.52413f-5 1.9951676f-5 0.00011362288 0.00013464814 -3.9794075f-5 -0.00010163501 -1.7881948f-5 -1.4309177f-5 -7.552543f-5 -4.6880006f-5 -1.8024142f-5 0.00010681851 -2.8325019f-6 -6.887637f-5 -0.00011816295 
-0.00013091834 -9.027047f-5 8.251078f-7 -9.6152355f-5 -8.74265f-5 3.556765f-5 0.00010680221 2.5876518f-5 -0.00018464295 -6.3273255f-5 4.5984752f-6 -9.950512f-5 -5.4199216f-5 3.3272598f-5; -0.00022728012 0.0001752915 0.00016761143 -5.8250956f-5 -4.597111f-5 0.00017409165 1.9714265f-5 1.8470922f-5 7.250013f-5 0.00016148861 0.00015986353 0.000101119134 -0.00016122306 -7.106707f-5 -0.000107560976 1.815542f-5 -7.950301f-5 -3.4698445f-5 0.00016073228 -6.9706686f-5 0.00011083724 2.6666705f-5 5.3909058f-5 3.514501f-7 -4.3301858f-5 5.697248f-5 -0.0001152433 9.781944f-5 6.0724848f-5 -0.00013642793 8.2877596f-5 1.4836322f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.00010993963 9.4818715f-5 0.00012867052 -0.00016056378 -3.09458f-5 -5.7810077f-5 7.369396f-6 0.00010720569 -7.5375945f-5 4.1188938f-5 -3.887614f-6 0.00024769266 -0.00010468148 -1.3846774f-5 1.1111532f-5 -7.813644f-5 6.917941f-5 1.9713696f-6 3.9751405f-5 7.227133f-5 9.7936885f-5 -0.0001844626 -9.3200644f-5 -0.00016031312 7.9482605f-5 6.3070795f-5 -5.4387154f-5 -6.356012f-5 -2.5527108f-6 -1.9034385f-5 8.1940736f-5 0.00011765598; -0.00012862429 2.730866f-6 -8.435519f-5 9.776624f-5 -9.375284f-6 -5.9473332f-5 4.2943233f-5 -6.6416695f-5 -0.0001498843 -5.8007867f-5 7.978445f-5 8.42199f-5 6.1929466f-5 -1.2474613f-5 6.472484f-5 -0.0001414037 1.3270919f-5 -0.00010545542 1.8168863f-5 3.7965605f-5 -9.658567f-5 -0.00013355011 -0.00017869037 -4.0648305f-5 -1.4269099f-5 4.8374815f-5 9.745105f-5 0.00011293062 0.00016775556 4.768675f-5 6.931326f-5 -0.00011703337], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we need Float64.

julia
# Convert the Float32 parameters to Float64 and store them in a ComponentArray.
const params = ComponentArray(f64(ps))

# Wrap the network together with its state `st`. No parameters are stored here
# (`nothing`); they are supplied at call time as `nn_model(x, nn_params)`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics. It uses

$$u[1] = \chi, \qquad u[2] = \phi$$

where p, M, and e are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the Newtonian orbit equations with a neural-network correction.

`u` holds the state `[χ, ϕ]` and `nn_params` the network parameters; `t` is not used.
The constants `p`, `M` and `e` are read from the global `ode_model_params`. Both
derivatives share the rate `(1 + e cos(χ))² / (M p^(3/2))`, each multiplied by one plus
the matching network output. Returns the two-element vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, so it can safely be
    # ignored. In general, `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([first(u)], nn_params)

    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Solve the neural ODE with the untrained parameters, using the same solver settings
# (fixed-step RK4, saved at `tsteps`) as the reference simulation.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the waveform tuple, matching the reference data.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained neural-ODE waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Prediction with the untrained network parameters.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # One legend entry per waveform, each combining its line and scatter handles.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    # Last expression of the block, so the figure is the block's value.
    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the reference `waveform` and the waveform predicted by the
neural ODE with parameters `θ`.

The neural ODE is solved with fixed-step RK4 and saved at `tsteps`; the first element of
the resulting waveform tuple is compared against the global `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted = first(compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params))
    return mseloss(predicted, waveform)
end
loss (generic function with 1 method)

Warmup the loss function

julia
# Evaluate the loss once at the initial parameters before training starts.
loss(params)
0.0007134816960284321

Now let us define a callback function to store the loss over time

julia
# History of loss values, appended to by `callback`.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses` and print a progress line with the iteration count read
from `θ.iter`.

Always returns `false` (presumably telling the optimizer to keep iterating; confirm
against the optimizer's callback contract).
"""
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This gives good results because the Newtonian model provides a very good initial guess.

julia
# The objective only depends on the network parameters; the second argument that
# Optimization.jl passes to the objective is ignored.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, started from `params` and capped at 1000
# iterations; `callback` is invoked by the solver to log the loss.
optimizer = BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01)
res = Optimization.solve(optprob, optimizer; maxiters=1000, callback=callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.0001088696008080962; 3.962059781763711e-5; -3.591880158634714e-5; -4.8024161515053535e-5; 0.0001341617171419447; 7.012121932349089e-5; -3.337585440017773e-5; 3.2921514502888446e-5; -1.3689365232498016e-6; -7.938637281758e-5; 8.268212695817819e-5; 0.00014754795120068917; 8.738114411234437e-5; -7.796217687419333e-5; -4.814505155084293e-5; -3.0941082513898606e-6; -7.92025275586712e-6; 9.404848969996541e-5; 5.253798371985523e-5; 6.660030339834968e-5; 2.6913743567970998e-5; -5.514224176292144e-5; -5.2421561122130875e-5; 2.5640087187609456e-5; -2.634258271427083e-5; -0.00011485501454443145; 0.0001400187757099751; 5.7425189879636564e-5; 0.00015158491441969856; 0.0001426457311023592; 5.717896056010922e-5; -1.0874880899774115e-5;;], bias = [-1.7131383405654126e-16, 7.477762306986586e-17, -3.341300879725104e-17, -6.833184137503479e-18, 3.6917834631988665e-17, -9.737648301572267e-18, -4.814638273777934e-17, 8.091114368057854e-17, 4.196752260085406e-19, -1.1313828585622747e-16, 2.408137628516539e-16, 3.291252268347281e-16, 1.118670647195323e-16, -9.763445290290664e-17, 2.901682167521196e-17, 8.72253701929659e-19, 9.353601016840086e-19, 1.2220007192031072e-16, 1.1726142592053384e-16, 1.0927178167809086e-16, 3.571578988590798e-17, -3.483754089772401e-17, -4.8832668682074235e-17, 3.405134328745977e-17, -8.388581121717956e-18, -1.509098767465593e-17, -8.145394797749911e-17, 1.0085616293595527e-16, 1.1602510545116997e-16, 3.568011000764439e-16, 1.5098013750328756e-17, -1.9903767757250583e-17]), layer_3 = (weight = [1.1879790607259679e-5 -3.62523755190906e-6 -8.995469860726573e-5 -0.00011778160577866785 0.00010128251133958321 8.64869970502931e-5 0.00010011779787257707 6.397594089160135e-5 -4.501322193186082e-6 8.317095753975622e-5 -0.0001147700505423691 -8.53885695430777e-5 4.199711739440862e-5 -2.9971280211406394e-5 -2.8108129383121366e-5 1.4857243430220236e-5 3.171439576201532e-5 -9.855329434847043e-5 
1.163996867783301e-5 -0.00011619375172065167 0.00012655059518531356 8.140006949199762e-5 9.63106821547861e-5 -5.3497725200706945e-6 1.1254552108062631e-5 4.611003069678752e-5 -3.0392996533713465e-5 0.00010739302789099914 -1.910405543801381e-5 -0.00010599510014051122 -4.399180497036847e-5 -0.00014564455076813306; -4.738641525870277e-5 6.519060973407408e-5 2.749669276856492e-5 -0.00014160632288219623 0.00012827496668655085 3.200681878699998e-5 -0.0002715947253596631 -0.00018892042005570954 -9.632611779632516e-5 -2.532268315514786e-5 2.9480648140113884e-5 3.765046624216326e-5 1.3620973208552727e-5 -6.493649432338571e-5 0.00012551360980471383 -4.945314374306706e-5 -5.531074362425081e-6 6.205138326952742e-5 -1.2948210110677634e-5 -0.00014234994030227864 -0.00020147539050920324 0.00013574234019363162 -0.00022685966412207278 -7.831703056198767e-6 7.332243824241291e-5 2.978320866928558e-6 -3.590846843104042e-5 -9.282291882807492e-5 -3.292600247327963e-5 -5.12643624238585e-5 0.00014684424370391187 -1.800138904907729e-5; -0.0002790773367298147 -8.47255367109001e-5 -8.890414827255628e-5 7.574465345606182e-5 -5.008002775125403e-6 -4.318652630167456e-6 -3.8247177856826355e-5 -1.490879529519348e-5 -4.966655327005496e-6 3.550333278035987e-5 -0.00010433899368144293 1.8700801971879564e-5 -4.060650170066995e-5 0.00011850997472429874 4.842253767201652e-5 1.950866245970415e-5 1.8541071055371406e-5 -0.0001288410210538055 9.11793824496307e-5 -1.3552993353364288e-5 1.3277014856432635e-5 9.971939218012639e-5 -1.2822796437100012e-5 -5.3542663211553794e-5 -0.00018392223208429424 0.00011963708604235908 9.600620907514561e-5 -4.6740854289233284e-5 -4.94586936685493e-5 -4.8429480360194364e-5 -0.00010071016802903225 4.498803465758277e-5; 0.00022776734464225137 -1.3626420821476976e-5 0.0001661611857906037 -1.7991549737583036e-5 -4.511541204754382e-5 3.169366015010172e-5 5.6690489020017576e-5 -0.00013269456770847955 -0.00010962797035807236 -6.213163227419686e-5 1.491533099562553e-5 
1.8726413929308215e-5 6.0259986514045706e-5 3.560361039324482e-6 -0.00015176774844638457 8.95777228486775e-5 1.780893386405004e-5 -0.00021597923854477043 -4.18856290145256e-5 1.9228369512181964e-5 -0.00010008292532637196 1.8510563551733268e-5 -1.5239653419651136e-5 2.5441255149250448e-5 9.99735092244163e-5 -0.00011037502930610642 -5.46862430730285e-5 -0.00013024006706308008 0.00015833307030499875 3.101120079174558e-5 1.2600606952962507e-5 1.6789370294479163e-5; 3.359844641167033e-5 -4.766125993698435e-5 0.00011620791565731545 -4.325496910966473e-5 -0.00010151094965763535 0.0001506756766546771 -2.25035050745365e-5 2.262423413852127e-5 -6.541511534278549e-6 0.00015471736924595048 -6.610336776717675e-5 -8.15603393631599e-5 -6.868218542437115e-5 7.036228135333226e-5 2.155413904332086e-5 -5.16619197652927e-5 -6.035042815271485e-5 -2.971790990628714e-5 -0.00023250190633477145 4.184745057940184e-5 2.2945840560014904e-5 -0.00018900068743748068 -0.0001221603210935133 -0.00010271864220628188 3.773758226528163e-5 -0.00015036976173647278 2.1416992699275554e-5 3.162680014141444e-5 0.00010904813325837508 1.4781396662190515e-5 -3.9689287855405574e-5 -4.687540922092921e-5; 1.8504208955961645e-5 -4.5014328034201385e-5 0.0001639919600255528 0.0001125827138001775 -7.414434102440358e-5 4.5996645812805385e-5 5.8490745604896994e-5 -7.81564350201077e-5 -0.00013067959397642718 0.00017495962211785295 0.0001481046155366808 -1.7554172782292263e-5 -0.00010706617203152301 7.610650473895862e-5 3.0005029575361476e-5 7.186531994181842e-5 3.3044975598269756e-5 -0.00011352661848843547 -3.307863979910244e-5 2.3179371882973634e-5 5.192046856779785e-5 -0.0003419547620889052 -0.00011846823062126269 0.00013143353399743503 -3.562405353489153e-5 -8.000574333902734e-5 8.053769224606908e-6 -0.00010941180252314268 -6.417547146167527e-5 -6.431647952023915e-5 4.4238246946291024e-5 -6.528576804169294e-5; 0.00012926652610688633 -3.850752497170384e-5 7.131873329557229e-5 -8.08281833663054e-5 8.281131207398005e-5 
0.00010878041260797105 -0.0001569228474757099 -4.850737192397245e-5 0.00013075284417640235 -1.938275405633758e-5 0.00018767154009468335 -7.466490884623812e-5 8.251819284554013e-5 -9.326030655144266e-5 -8.418301570839914e-6 -4.293537809941172e-5 9.740797883512427e-6 -5.2662987821592616e-5 -0.00015377195037514295 0.00011626233075433938 0.0001806635559994333 -3.6789189955596525e-6 2.589303452326596e-5 1.4515113013419472e-5 -4.708404181952809e-5 -6.162353335027484e-5 -3.15327228811591e-5 -0.000185091135302413 7.931546909887893e-5 6.824253845181994e-5 -1.2749147993657937e-5 0.0001981884857166167; 2.1003697048804955e-6 -9.056837060312172e-5 4.215078971641699e-5 -0.0001533017644641024 -3.3174263614322345e-5 -0.00013354689925732367 1.1770257891058903e-5 -5.858126731878521e-5 0.0001406129735306796 1.0564933666631048e-5 -1.4825445384173586e-5 -6.536748750191019e-5 6.07249598152915e-5 -4.0798019278424456e-5 0.00024411975701981811 -5.0489703959882564e-5 -0.00020465074393256037 -0.00012276590913357386 -0.00013126824395554473 -5.930100504597937e-5 0.00020214985071392348 -7.358855563039285e-5 -2.0325064348851855e-5 1.9396772962082966e-5 0.00015824325911411634 -9.354083939511391e-5 -0.0001002578652213317 0.00013379471730617487 0.00011016996669527336 1.833931802425193e-5 7.450628564869667e-5 2.2441601152551083e-5; 8.210345810288826e-5 1.1130400584758117e-5 5.434340970600385e-5 0.000153151001418106 0.00011847215033844318 -8.464061588942065e-6 2.5340109336635924e-5 -3.0215590536730466e-5 -0.00016692276870822165 7.748337781299219e-5 2.1880872713681988e-7 1.824866818855687e-5 3.3363640929980806e-6 -0.00010303187185423133 -6.012134429989032e-5 8.789793615937607e-5 -9.790344223370402e-5 -2.6547614775385596e-5 4.337485994687651e-5 -7.336350712217061e-5 1.6326658501153748e-5 0.00013015022965623404 -7.378268959223875e-5 -2.019399938513158e-5 9.824767299937451e-5 -6.697432867008942e-5 -0.00014070571809066275 -5.082886958485127e-5 -1.0559288180280291e-5 1.424147818829116e-5 
2.6452832448509224e-5 4.077926393142341e-5; 0.00012021564948932032 4.262237821800026e-5 -9.430267486626875e-5 -7.043361215059361e-5 -0.00013792332133184752 -9.576732926181715e-7 3.3427124848602096e-5 5.761451070989831e-5 2.8984523354810087e-5 2.847964465595848e-5 0.00013105423609354567 -6.190974777005454e-5 0.0001228849946111533 -1.830866999247964e-5 -0.00029956609755194463 0.00010656437412231637 -7.827743642869002e-5 5.3216034261052715e-5 0.0003344990331735581 7.52670170678867e-5 -4.085698633011062e-5 2.9902949294552775e-5 -4.167523688542397e-5 0.00016656397970487733 -9.688598178931142e-5 -6.588099453970168e-5 9.473342701446232e-6 0.000125842467654508 6.484376395214443e-5 2.0055990220631966e-5 -0.00017835430439467316 4.496473452869273e-5; -6.585559317366168e-5 -4.277599732498037e-5 -0.00011953573853000405 -6.444955074451232e-5 9.716824711310959e-5 5.374132844254188e-6 2.0397459185431975e-5 5.180677685663688e-5 -2.6752099482593273e-5 -0.00021422699360166831 0.00010142610840509514 5.864026166043664e-5 -7.558516376641683e-5 -1.1480800320468131e-5 6.197876932807897e-5 0.00013103825181266836 4.649633551841217e-5 -3.735372451822332e-5 2.2855437738571265e-5 -7.802672048322145e-5 -4.476072576702002e-5 -3.1710280368591873e-6 5.800435387888637e-5 -1.0149493799446058e-5 -0.000240128995254537 -0.00017123567381397719 -4.310272056365977e-5 4.2798788446334337e-5 3.549111217105527e-6 1.5276883961156894e-5 0.00022288570612678967 2.4682670802093006e-5; -1.4508119017362818e-5 2.6181977267014307e-5 0.00020183116834527853 3.082601000937693e-5 -1.9664204552779795e-5 4.5101507970002684e-5 8.39331809795558e-5 -7.605817114085857e-5 -5.720938821173261e-5 -6.222133702681259e-5 -6.918676404598017e-5 -9.274953804992611e-5 -0.00022763253542211668 0.00012432097395799257 -2.9693453649770796e-5 -0.0001032752803046162 0.00011539921655935248 -3.556301776984513e-5 0.00011473250601125019 -6.696031374008291e-5 4.178221246843254e-5 2.116158652211149e-5 -3.270658231190337e-5 1.250139410557425e-6 
-4.389721976326752e-5 0.00015315020977436849 8.22193509914116e-5 0.0001220197850154814 -0.000122393884845902 -7.478656658053897e-5 -0.000125843067104825 -6.240937687539482e-5; -5.128938366835928e-5 1.15843406776366e-5 -4.6837021477422155e-5 -0.00011459636309593446 0.0001241628181334556 3.20353704184328e-5 0.00018152316467892905 0.00016174727195464599 -0.0001383055366430374 0.00012789715972319523 -0.00012062872858661907 -3.453471413041161e-6 4.1080411642885095e-5 7.286447066851502e-6 0.00016045766667519232 7.000577013530418e-5 -9.666581329367477e-5 7.599393418347618e-5 -0.000191766435568589 2.7428730641239258e-5 0.00025675895292368337 -2.7399779105345102e-5 -0.00016025567467733772 0.0001045894880691598 0.00010669528844589915 1.6277739882910995e-5 -5.2233806604662836e-5 -9.747847227556209e-5 -2.664865471600152e-5 -6.657911010263888e-5 -3.0858687500604855e-5 9.372353658835049e-5; 4.2071367512735614e-5 -2.5026067508682913e-5 -0.0001683725479081694 -3.156038494426359e-5 0.0002235533263744281 7.256977712740457e-5 0.0001728988516121683 1.562905506643461e-5 -0.0001102803564420384 9.948296954161684e-5 9.750622463519042e-6 -0.00010099692574232474 1.3351365648025065e-5 9.297755039069079e-5 -1.67009077115876e-5 9.409062638651164e-5 -0.00017422781659052198 6.610700236405314e-5 0.00013430349079198578 -6.0945569405981235e-5 4.726473869295338e-5 -5.128918222276146e-6 -4.7553632824062055e-6 0.00013098930665064151 -8.224872565124634e-5 1.959306225825413e-5 4.6137852929574315e-5 1.556400800536269e-5 4.364771464377926e-5 -6.556868740818978e-5 1.4131457443631952e-5 -0.00012086343324272384; 0.00021988558198916144 -2.515170838033759e-5 -1.3061787124217446e-5 3.5302806443166364e-5 -6.657797150422182e-5 8.03185155508246e-6 1.4738382721942143e-5 0.00019335805170829948 -3.797796109448468e-5 -5.454145304849378e-5 0.0001474438778835808 0.00015594222548076194 -0.000201265456619348 -4.1073047962179236e-5 -0.0001895863368396141 -2.649359869120458e-5 -2.4172177129558364e-5 -1.27571591498241e-5 
0.00030049709510184257 3.165295472521819e-6 0.00010999173007566123 -1.8284586824496666e-5 -9.093535283919991e-5 -1.8532222222903512e-5 5.8814834280736945e-5 1.7713751892528692e-5 -7.776263212356687e-6 0.00014086315257135155 -5.090699216692592e-5 5.167304313878248e-6 9.438642353178312e-5 0.0001704457846948475; -4.5676101436201806e-5 4.607420792983648e-5 -0.00019457927614037132 4.375365989963851e-7 0.00016387544688715704 3.3884355943358036e-5 8.293161168642162e-5 -0.0001005821619770135 -0.00012228600157011433 4.4100739032203274e-5 4.431444139683823e-6 7.494825636914892e-5 -4.252113413329572e-6 -0.00019235728599649297 -7.982427690929028e-5 2.3076342913032073e-5 -0.00013568639464263009 0.00013724299729749083 4.845885938213896e-5 5.00954950979787e-5 -1.7214074105663932e-5 9.07698043733251e-5 0.00015663271129544586 0.00013575836725015033 1.7525989605405602e-5 -9.262799052574384e-5 5.630933024751065e-5 3.0900118289922534e-5 8.60935536906394e-5 -1.2326126707724621e-5 0.0001245058950730015 -0.00011543929635133737; -0.00017868783792874895 -0.00018292842609737007 0.00010923914347147612 1.4556054752737992e-5 -0.00020078898118736828 9.482104765555198e-5 -0.00011334307377621501 9.089237623974028e-6 4.6555619023981926e-5 2.8017061099656282e-5 -9.745117453787028e-5 -0.00025085036354067125 -5.096195267907771e-5 2.6930545987108916e-5 -0.00013050922602344189 0.00013878370050877283 -0.0001629451249253382 -5.8496483274953086e-5 2.677632851536956e-6 0.00012132725210465764 -2.8935657196436055e-5 -0.00018865905762579125 -7.82711189896883e-6 0.00017781590982411056 0.00012233604907594623 0.00015867948646263198 2.0343699775542876e-5 -2.5074453828728728e-5 -7.490448801473022e-5 -0.00016873373128401548 0.00010193850580906682 0.00019643122910127002; 0.00011492239006693742 0.00010545270212825147 3.905651116875702e-5 2.2090265856978563e-5 -9.266372201600095e-5 1.4426035859660185e-5 2.21979664005736e-5 -5.9627716681996704e-5 -0.00013968139667630376 0.00014999511056012973 -5.365151430814052e-5 
6.912389669232863e-5 0.0001802910484642859 2.6153380686454875e-5 6.672361648307843e-5 -2.2047635711340144e-5 -2.7030608025222947e-5 -1.9722254209765968e-5 -7.60041728695697e-5 -8.00572159547208e-5 -2.2141155837636608e-6 -6.28417108869955e-5 -0.00011841483485985835 0.00010537842915292626 5.9766393106121375e-5 0.00024229622083149133 -9.650870181721685e-5 0.0001504048924929617 4.8690181002167976e-5 -7.01712527151997e-6 -0.00022088438693534896 -2.682903671446904e-5; 6.055588009752945e-5 -0.00012598215262420805 -8.868134473465737e-5 -0.00021455183168888353 1.4743506924286589e-5 1.0758012920119587e-5 5.629706201710369e-5 9.285507346641746e-6 9.63670884606754e-5 0.00016247442390289262 -0.0001449431963033586 9.661245983309406e-6 5.608329317909609e-6 -0.00018473539394341324 -0.0002571304591606956 -4.0165997872981286e-5 5.820449796140661e-5 -2.0725135711982263e-5 2.231010175585994e-5 7.552134620476447e-5 3.616626263231567e-5 0.0001270722334970142 -4.031447470203261e-5 -6.269832529375743e-5 -0.00010748487848338149 -0.00010238609111416347 1.9907926159658644e-5 -0.00019486291576191952 0.00020166916355943073 -0.00014451577744925865 0.00020486837301853872 -0.00013707705666326448; 5.6449204882470055e-6 2.473231480934244e-5 -1.2272853641945124e-5 -0.00014921314349095218 -0.0001442274681588019 -0.00012657875985069784 2.3100435234374645e-6 2.8367074005189443e-5 -1.8174107285915797e-5 -8.898511649504763e-5 -5.1760338724585665e-5 -0.00011616831969630933 9.809922422818497e-5 -8.373931840224673e-5 -0.00023409207850391857 0.00018484606490773062 -4.359793664235703e-5 5.822971226252379e-5 -3.8257387391522815e-5 -2.2717015131373628e-5 0.00015920257802889656 -2.8628168625173206e-5 4.23185026326698e-5 0.00018161909039434947 -7.727271648519698e-5 -1.9501049141854436e-5 5.113902235707981e-5 5.8162446034375765e-5 -0.00017561744815365333 -0.000119988364790782 -5.376074373370807e-6 -7.022515351593787e-6; -1.0591727908969376e-5 -0.0001449158698907662 -9.848662375058836e-5 2.0040315581200127e-5 
-1.6613288602583893e-5 2.4561106334508748e-5 -0.00010727237894950322 -9.393384606862333e-5 1.2439450989245976e-5 -0.00016570470688217223 -0.0001025986894038953 1.614859139032689e-5 -0.00010462092540743431 -0.00017927509542838612 -0.0001294406356411582 -5.923832577006976e-5 8.774859089119287e-5 3.566367741390674e-5 7.935360554786799e-5 -4.027005342701686e-5 3.7525998783009536e-5 8.152834562300223e-5 -0.00012003593655587708 -0.0001382970185606365 2.6353372964719672e-5 9.149292964671185e-5 -1.5101701311654117e-6 8.211205932260055e-5 -0.00010853240746121613 -0.00021549817392418387 -6.27224981091495e-5 -2.115450842591068e-5; -2.6555198969929585e-5 0.00011435859774307767 0.00010724001181613182 4.604017517946984e-5 -8.599599827480399e-5 -2.8677097402881127e-5 9.783393029890505e-5 8.910249012358041e-5 8.900900134419722e-5 -8.99543593235055e-5 0.00018306437410016267 0.0002626060158533098 4.1606275730976735e-5 -6.701325769002725e-5 -0.00017706736984594404 -0.00013290837527649783 2.5722996683429618e-5 6.287632273111879e-5 6.159860636641533e-5 9.575931467649568e-5 5.212433457443092e-5 0.00011833933963741401 9.26950087474965e-5 -0.00012114560381208005 -9.19248127077169e-6 -0.00013915088497501053 -4.194286161269168e-5 0.0003609996055771896 3.2673840864518704e-5 -4.4648179068967896e-5 1.461230472584218e-5 4.928140329382064e-5; 0.00014270865874096032 -0.00017708906155348183 -0.00012622051356628883 0.00018270097431923813 -9.09505927938875e-5 5.932410719652189e-5 -0.00019651866235107526 9.405239934501142e-6 -1.8268672947663064e-5 0.00010570901728897766 -7.323765960790755e-5 -0.00013007453010686809 7.960154986904369e-5 1.4625574904134967e-5 -0.00019440370154370006 1.9019330875449732e-6 -2.68617789788481e-5 0.00012108887101021918 0.00022723398221388062 1.185422813251161e-5 0.00010062447624478168 9.213618495783915e-6 -0.0002051293507338355 -3.711437087076106e-5 0.00013770390584796343 0.00012197638504403813 2.8103869835831583e-5 6.419165371881383e-5 5.447357913908373e-5 
4.290840085559441e-5 5.279200347104628e-5 -2.106162567542236e-5; 2.9198630169105873e-5 5.218989043538142e-5 -2.499453204467757e-5 -9.503740203758046e-5 7.342346531064712e-6 0.00012873260801466967 -0.00021149892619536594 -0.00011748601592802782 1.9742718778220042e-6 3.732550556929704e-5 -0.0001022777365632826 0.00010412661448396513 -6.069854918126958e-5 -8.911652604541832e-5 4.078504879185652e-5 2.524848731450807e-5 5.5037445587055556e-5 -5.323133210726031e-5 -3.887288459185767e-5 -0.0001695101751592493 -2.0702337805625097e-6 -5.070458484073412e-5 -8.537977568214069e-6 -0.00013059756803517593 0.0001316765914326032 2.5551914755909604e-5 2.1134941003826792e-5 2.2562802642444506e-6 0.0001391044564898668 6.26816757958911e-5 7.042730327210178e-6 -0.0001096352503863668; 3.887615011873592e-5 2.1094902966735825e-5 0.0001758611716456771 -0.00010563948753885484 2.2746478042522826e-5 -8.59303136282784e-6 3.2009462290563794e-6 7.256706439760208e-5 -0.00021954143167137312 0.00013221341496202937 -6.0411196410565236e-5 4.849918614008772e-5 0.00013687334567177682 0.00011353173895142847 6.959408628857202e-5 -2.2050872668173676e-5 0.00014214987013359403 6.900447906329688e-5 6.309306181604415e-5 0.00017193260564338534 -5.1305008038312845e-5 -2.135906734226406e-5 -9.343110732631383e-5 6.932264941381981e-5 0.00011496282522669746 4.2172388110623855e-5 -9.658668286765293e-5 -7.15298418993101e-5 0.0001121803316879751 6.97381284214329e-5 0.00014824929084964397 4.85895826375189e-5; 3.607695102174462e-5 0.00014882577228774366 -0.00012182313808523991 -0.00011339465423305496 6.379584777379149e-6 2.2446284804678788e-5 0.00013923655491736785 5.158564740042267e-5 -1.3276758051581132e-5 -1.135824809979063e-5 0.00017081253749267066 4.8598990491072724e-5 0.00012251417787694212 4.604478739849185e-5 -6.941144027977803e-7 -1.54342240938441e-5 -0.00022623683331119727 0.00016651505133551348 -5.139916616543188e-5 0.00011365544482315839 7.3201928119595e-5 -9.963553462228742e-5 2.0424908080866634e-6 
1.8847581616789894e-5 -0.00010503559029246602 8.132455254154862e-5 -0.00013421209301385666 -0.00022338925633165065 -1.8434556165530686e-5 8.780807106005366e-5 -5.401249732756998e-5 0.00011228151119495841; 2.00661795683836e-5 -3.309286716060646e-5 0.0001054014062542286 2.1975674921737237e-6 0.0002032926929682492 -0.00011609531815295605 -1.5847217800486643e-5 -0.0001744703115396386 -7.674383254155617e-5 4.382873135952862e-5 -0.00013136599613019557 7.926685044411973e-5 -6.0657764511471015e-6 -0.00010338627929698118 8.797745955277383e-5 -0.0001499050197158205 -1.6128628194142157e-5 1.3550082040160642e-5 0.0001376716170493347 -1.6360374720106757e-5 -0.00020217787117829568 -3.63649562491733e-5 0.00010643819383446324 -0.00014636413128713173 -1.2911573539471989e-5 -4.799506517840148e-5 1.3597428515357237e-5 -0.00010369993855377403 9.128466605078962e-5 -9.469711971133703e-5 -3.355369235151825e-6 -0.000255866725855899; -0.0001426199123282614 1.763864854803326e-6 -0.0001741475679930939 -1.8251814214925824e-5 1.7746985283189571e-6 0.000182748572153907 0.00020135943854500662 -9.970858441858436e-5 -9.123057641153083e-6 -0.00010858387755189847 -0.00010203120105626811 5.44934228334424e-5 9.063096880953018e-5 9.046151756303719e-6 1.1927159942098817e-5 9.194223461479985e-5 -3.0315197325142883e-5 1.7197755407944488e-5 -5.780158968020094e-5 3.5851517207158415e-5 -8.841070298996641e-5 -5.9962047050538005e-5 -8.773138320731795e-5 -0.00018465908575064386 -2.1057448936806775e-5 0.0001285656945062174 2.3083566741373804e-5 -8.401536247885643e-5 0.00014211998863418942 -4.34780573170071e-6 5.833558320108012e-5 0.00013789338485610783; 8.132566140281e-5 -4.270644935879818e-5 2.838073155451166e-5 -0.0001445251924915522 0.00017749232060755896 -5.0382904783899996e-5 -0.00010458046895203097 -9.498634758622298e-5 -1.1447695806986134e-5 0.00015961700987551867 6.33032817395077e-5 0.00017904383125728858 -3.528080710660289e-5 9.024133168140601e-6 4.549886066249537e-5 -4.646444144875551e-5 
5.876923237244785e-5 -7.016218446645625e-5 0.00019826421916966248 -3.286947471772562e-5 -3.974322103297217e-5 -2.4438507945003108e-5 2.255867579967638e-5 0.00011108887551834003 5.277504746934524e-5 7.933255833355545e-5 2.003548283842532e-5 -6.645480029967744e-5 -6.999817424956279e-6 0.00011820422364412875 9.69640080702199e-5 1.8908635273852328e-5; -8.486052994885402e-5 -7.737175785041483e-5 7.427861345363422e-8 0.00012142668545675406 1.600584126390685e-5 6.99815105526799e-5 0.00023448872942208237 8.796499294658116e-5 -8.548886709650139e-5 -7.469302674326115e-5 5.586777531971232e-5 -0.00010647532637267152 -0.00018843771560386738 -5.5976273341001974e-5 -7.476700868028274e-5 0.0001393987009414556 -9.13492289491845e-5 -6.47186638325356e-5 0.0001057309242078191 3.610039431805805e-6 -1.5299784916351415e-5 -0.00018568111719832804 -7.31839785821648e-5 0.00027632122199247513 0.00012021763230404736 5.1702627917622284e-5 -0.00015020695301973066 8.390507408148566e-5 0.00010316652022266068 -3.6070348340215e-5 -0.00012002366149135908 -9.556120077650152e-5; -3.999182207022907e-5 5.277139154515871e-5 -0.0001644484849586463 6.523799913140734e-5 1.9948375318419084e-5 0.00011361958115032944 0.00013464484310845507 -3.979737574098043e-5 -0.00010163830791696434 -1.788524912824169e-5 -1.4312477453287711e-5 -7.552873046929925e-5 -4.6883307170147165e-5 -1.8027443167895913e-5 0.0001068152074529489 -2.835802781355369e-6 -6.887966936317726e-5 -0.00011816624954205445 -0.0001309216435239946 -9.027377144063804e-5 8.218069004406936e-7 -9.615565557594542e-5 -8.742980333606876e-5 3.5564348416138294e-5 0.00010679890930789138 2.5873216707473924e-5 -0.00018464625225729516 -6.327655626789769e-5 4.595174339417498e-6 -9.950842412050011e-5 -5.420251671437281e-5 3.326929674387463e-5; -0.00022727771495557235 0.00017529390395065793 0.00016761383965058297 -5.8248550274586655e-5 -4.59687019683919e-5 0.00017409405488433077 1.971667158795579e-5 1.8473328666003092e-5 7.250253626989343e-5 0.0001614910194548371 
0.00015986593432170882 0.00010112154056694668 -0.00016122064910451399 -7.106466576079234e-5 -0.00010755856995092741 1.8157826772948307e-5 -7.950060019356993e-5 -3.4696039008857553e-5 0.0001607346836608455 -6.97042799656672e-5 0.00011083964425051672 2.6669111022780655e-5 5.391146412148071e-5 3.538563221340776e-7 -4.329945143399669e-5 5.697488602162387e-5 -0.00011524089707383405 9.782184285722561e-5 6.0727253796339355e-5 -0.0001364255243315878 8.288000178435708e-5 1.4838728631729218e-5], bias = [-9.194243308590136e-11, -3.396576703543154e-9, -2.0743602346155134e-9, -5.130806887056088e-10, -2.449719586445846e-9, -1.1396788324687862e-9, 1.4993354287049843e-9, -8.008890825620921e-10, 2.863678187177434e-10, 1.6666073486012606e-9, -1.1336383355409235e-9, -6.362431588288651e-10, 1.4681747241238484e-9, 1.551394868203873e-9, 3.188849759275915e-9, 1.1781225339877345e-9, -2.236942984672716e-9, 1.258895656188114e-9, -2.8027921589612606e-9, -2.7087347124945775e-9, -5.438773018236728e-9, 4.045391210127416e-9, 1.081998568780188e-9, -1.5357854349978545e-9, 4.6230513681354925e-9, 1.0505965455375566e-9, -3.3107785704518504e-9, -3.3101725285143975e-10, 2.974365066793499e-9, -4.396661100640662e-10, -3.3009038632019544e-9, 2.406224665747909e-9]), layer_4 = (weight = [-0.0008050867679491907 -0.0006003281791278425 -0.0005664765249521214 -0.0008557109106108843 -0.0007260928058807353 -0.0007529571845772793 -0.0006877776912535193 -0.000587941431472655 -0.0007705230791265141 -0.0006539581393544627 -0.0006990347219902409 -0.00044745446757750435 -0.0007998285667556431 -0.0007089938576758136 -0.0006840353828542414 -0.000773283548925589 -0.0006259676180108573 -0.0006931757325370636 -0.0006553955630401797 -0.0006228756501960467 -0.0005972096395446194 -0.0008796093594784358 -0.0007883477545033008 -0.0008554602062898059 -0.0006156640824808987 -0.0006320763178464079 -0.000749534045282747 -0.0007587072507104326 -0.0006976996513545098 -0.0007141815169979892 -0.000613206169012766 -0.0005774910356052486; 
0.0001151507888064497 0.0002465058619819954 0.00015941985972456582 0.0003415413202411553 0.00023439975030339097 0.0001843017383987784 0.00028671829622706945 0.0001773583810238147 9.389077479298205e-5 0.00018576719297408518 0.0003235595181861572 0.0003279949780254968 0.00030570452927874557 0.0002313004490709848 0.00030849984117980165 0.00010237137364456832 0.00025704596253383726 0.00013831964527490713 0.0002619438845034549 0.00028174063083284 0.00014718919969858437 0.00011022483588599187 6.508469731536562e-5 0.00020312675651456698 0.00022950582473671774 0.0002921498874388377 0.0003412260425178879 0.0003567056977489543 0.0004115305676103287 0.00029146182868047046 0.00031308826021274696 0.00012674166807889032], bias = [-0.0006951471358973036, 0.00024377508026377463]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # Loss history recorded by `callback`: a line with one marker per recorded value.
    lines!(ax, losses; linewidth=4, alpha=0.75)
    # `eachindex` takes the x positions from `losses` itself instead of rebuilding the
    # index range by hand with `1:length(losses)`.
    scatter!(ax, eachindex(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally let us visualize the results

julia
# Re-simulate the neural ODE, this time with the optimized parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
# Only the first value returned by `compute_waveform` is used as the waveform.
trained_outputs = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
waveform_nn_trained = first(trained_outputs)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Reference waveform.
    data_line = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    data_pts = scatter!(
        ax, tsteps, waveform; markersize=12, marker=:circle, strokewidth=2, alpha=0.5
    )

    # Waveform from the neural ODE before training.
    untrained_line = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    untrained_pts = scatter!(
        ax, tsteps, waveform_nn; markersize=12, marker=:circle, strokewidth=2, alpha=0.5
    )

    # Waveform from the neural ODE with the optimized parameters.
    trained_line = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    trained_pts = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        markersize=12,
        marker=:circle,
        strokewidth=2,
        alpha=0.5,
    )

    # Each legend entry groups the line and the markers of one waveform.
    legend_entries = [
        [data_line, data_pts], [untrained_line, untrained_pts], [trained_line, trained_pts]
    ]
    legend_labels = [
        "Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"
    ]
    axislegend(ax, legend_entries, legend_labels; position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are
# both defined in this session and the backend reports itself as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
  JULIA_CPU_THREADS = 128
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 128
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.