Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, they aren't super relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, §4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a relative 2-body `path` into the individual paths of body 1 and body 2.

Body 1 is placed at `m₂ / (m₁ + m₂) * path` and body 2 at `-m₁ / (m₁ + m₂) * path`.
The two results therefore satisfy `r₁ - r₂ == path` and `m₁ r₁ + m₂ r₂ == 0`, which
puts the centre of mass at the origin. Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    body₁ = weight₁ .* path
    body₂ = weight₂ .* path
    return body₁, body₂
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Change of variables (χ, ϕ) ↦ (x, y) for an orbit stored column-wise in `soln`.
#
# `soln` must have either
#   * 2 rows (χ, ϕ): the constants `p` and `e` then come from
#     `model_params = (p, M, e)` (M is not used here), or
#   * 4 rows (χ, ϕ, p, e): `p` and `e` are read per sample from rows 3 and 4.
#
# With r = p / (1 + e cos χ), returns a 2×N matrix whose rows are
# x = r cos ϕ and y = r sin ϕ.
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Check for `nothing` explicitly: `length(nothing)` would otherwise throw a
        # MethodError instead of this assertion.
        @assert(!isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v, dt)

First time derivative of the uniformly sampled series `v` (sample spacing `dt`).

Interior points use central differences. Both end points use second-order one-sided
stencils. `v` needs at least 3 samples, and the result has the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    stacked = vcat(left_edge, interior, right_edge)
    return stacked ./ dt
end
d_dt (generic function with 1 method)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v, dt)

Second time derivative of the uniformly sampled series `v` (sample spacing `dt`).

Interior points use the three-point central stencil. Both end points use four-point
second-order one-sided stencils. `v` needs at least 4 samples, and the result has the
same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    stacked = vcat(left_edge, interior, right_edge)
    return stacked ./ dt^2
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return `mass` times one entry of the trace-free second moment of a planar `orbit`.
Row 1 of `orbit` holds x and row 2 holds y.

With Ixx = x², Iyy = y², Ixy = x·y and tr = x² + y² (all element-wise):
- `component == (1, 1)` gives `Ixx - tr / 3`
- `component == (2, 2)` gives `Iyy - tr / 3`
- any other component gives `Ixy`
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    a, b = component[1], component[2]

    entry = if a == 1 && b == 1
        @. x^2 - (x^2 + y^2) / 3
    elseif a == 2 && b == 2
        @. y^2 - (x^2 + y^2) / 3
    else
        x .* y
    end

    return mass .* entry
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the finite-difference second time derivative (sample spacing `dt`) of the
`component` entry of the mass-weighted trace-free moment tensor of `orbit`.
"""
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole second-derivative components of `orbit`
for the (1,1), (1,2) and (2,2) entries, each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    hxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    hyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    return (hxx, hxy, hyy)
end

"""
    h_22_strain_one_body(dt, orbit)

(2,2)-mode strain of a single body with unit mass moving on `orbit`, sampled every `dt`.

Let h₊ = h11 - h22 and hₓ = 2·h12, built from the quadrupole components. The function
returns `(√(π/5)·h₊, -√(π/5)·hₓ)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # √(π/5) normalization of the (2,2) mode. The square root was lost in the previous
    # `(T(π) / 5)`, which scaled the strain by π/5 instead.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum, entry by entry, the `(h11, h12, h22)` quadrupole components of two bodies. Each body
has its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

(2,2)-mode strain of black hole 1 (`orbit1`, `mass1`) and black hole 2 (`orbit2`, `mass2`),
sampled every `dt`. The masses must sum to 1 within 1e-12.

Let h₊ = h11 - h22 and hₓ = 2·h12, built from the summed quadrupole components. The
function returns `(√(π/5)·h₊, -√(π/5)·hₓ)`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # √(π/5) normalization of the (2,2) mode. The square root was lost in the previous
    # `(T(π) / 5)`, which scaled the strain by π/5 instead.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Convert the ODE solution `soln` into the (2,2)-mode strain tuple returned by
`h_22_strain_one_body` / `h_22_strain_two_body`. The first element is the plus
polarization.

`mass_ratio` must lie in `[0, 1]`. If it is `0`, the test-particle (single body) formula is
used. Otherwise the relative orbit is split into two bodies with `m₁ + m₂ = 1` and
`m₁ / m₂ = mass_ratio`. `model_params` is forwarded to `soln2orbit`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # The comparison operators were missing here (`mass_ratio1`, `mass_ratio0`), which
    # made every call fail with an UndefVarError.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi, \quad u[2] = \phi$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side `[χ̇, ϕ̇]` of the orbit ODE for the state `u = [χ, ϕ]`, with constant
parameters `p`, `M` and `e`. The time `t` is not used.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)  # shared e·cos χ term

    numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = numer * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * denom)
    dϕ = numer / (M * (p^(3 / 2)) * denom)
    return [dχ, dϕ]
end

mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions
datasize = 250
# Keep the time axis in Float64 too, matching the Float64 state (`u0`) and parameters;
# a Float32 `tspan` would make the solver time and `dt_data` Float32.
tspan = (0.0, 6.0e4)     # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = step(tsteps)   # spacing between saved samples
dt = 100.0               # fixed integrator step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl.

julia
# Reference ("true") trajectory: integrate the relativistic model with fixed-step RK4
# (step `dt`, adaptivity off), saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns two polarizations; `first` keeps the plus polarization,
# which `loss` later uses as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    # Plot the reference waveform as a line with markers on top.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Group the line and the markers under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    # Last value of the block, so the figure is shown when this cell is evaluated.
    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the anomaly $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and the time as inputs and returns the derivatives.

Using globals is generally not recommended, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl.

julia
# Network with 1 input and 2 outputs. The first layer applies `cos` to the input itself.
# It is followed by two width-32 `cos`-activated `Dense` layers and a linear output layer.
# Weights are drawn from a truncated normal with std 1e-4 and biases start at zero, so
# the untrained network's output is close to zero.
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
# Initialize parameters and layer states. NOTE(review): the default RNG is not seeded
# here, so the initial weights differ between sessions; seed it for reproducibility.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-2.9276163f-5; -0.00014470128; 0.00014843476; 3.4592642f-5; -1.1858729f-5; -1.7285209f-5; 1.736001f-5; 9.485369f-5; -4.376713f-5; -6.9653586f-5; 4.8160837f-5; -0.00019014961; 4.0163162f-5; 6.106093f-5; -2.9839544f-5; -0.00013815572; -4.4137305f-5; -9.289838f-5; -4.5374247f-5; -9.694301f-5; -4.039668f-5; -8.097201f-5; 5.0551913f-5; 4.641294f-5; 1.2114686f-5; 0.000119926925; -0.00013948248; -0.00016010947; -0.00014308598; 1.841357f-5; -1.4092393f-5; -0.00016102797;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.00016946514 2.4912248f-5 0.0002281766 0.00010093294 -2.4289955f-5 -0.00013716482 5.286793f-5 -3.2660508f-5 4.585179f-6 4.0783f-5 -7.801501f-5 -3.7184705f-5 -7.878655f-5 -9.8163044f-5 0.000107228814 -6.1678606f-5 -1.2657113f-5 2.1680546f-5 -1.9371637f-6 3.1090214f-5 -0.0001712425 1.8760426f-5 -3.3667144f-5 -2.3667795f-5 4.0155923f-5 -0.00011868753 -2.5239862f-5 0.00013228183 1.7983552f-5 1.2745495f-5 -0.000114685965 -3.964395f-5; 0.0001512541 1.660323f-5 0.00010304205 5.3502106f-5 0.0001632344 -0.00013316033 -5.7535217f-6 -3.9056737f-5 1.145625f-5 -8.017172f-5 1.5073089f-5 -1.1803615f-5 -0.00013267988 -0.00011870696 -5.9680882f-5 0.00025988382 -7.511045f-5 0.000106434025 7.298311f-5 -0.00015029105 -0.00022340586 -0.00017725091 7.672516f-6 1.3132095f-5 2.0353227f-5 -7.1353665f-5 0.00014800244 -5.9993894f-5 0.0001257671 -0.00016376507 -2.6806416f-5 -7.1914415f-5; -9.9799945f-6 -0.00012297796 5.5209977f-5 -7.5869566f-6 -4.0776973f-5 1.904344f-5 0.00017733767 0.00013865864 0.000155722 -6.6295266f-5 0.00021649456 1.8911704f-5 -8.113935f-5 -9.474948f-6 -5.4892596f-5 -0.00013671377 8.986712f-5 -2.0558351f-5 0.0001031792 4.5437482f-5 0.000104551946 1.5504267f-5 -5.841583f-5 6.2689745f-5 -5.4600805f-5 -6.869095f-5 -0.0002095011 0.00023651766 9.836202f-5 
-4.386535f-5 8.823765f-5 6.670256f-5; 4.3504897f-5 -8.730537f-5 -6.9112524f-5 -6.0734812f-5 -5.3464024f-5 -2.0550038f-5 -3.1090706f-6 4.019219f-5 -1.7301603f-5 -8.629966f-5 -2.836264f-5 -2.392893f-5 7.450434f-5 -4.3774444f-6 6.4377346f-5 3.208238f-5 -1.6641683f-6 -7.463204f-5 1.3145872f-5 -4.653593f-5 -0.00013099286 8.2962506f-5 5.7654026f-5 -0.00010770348 -9.409229f-5 0.00019265279 -0.00013838106 -0.0002485651 -6.3249565f-5 0.0002241377 0.000116163326 -2.3299538f-6; 0.00019415283 -2.162553f-5 -5.3894408f-5 -5.8755245f-6 5.645222f-5 -0.000134717 0.00017449349 2.4808984f-5 -0.00014051922 0.00014285598 0.00013257611 -1.6330732f-5 4.773399f-6 -0.00011604026 5.034085f-5 5.104096f-5 0.00012121577 0.0001558422 2.0852893f-5 -7.723827f-5 -9.197578f-6 -9.5875526f-5 -1.9082529f-5 0.00016466895 0.00012229946 -2.3732535f-5 -2.0392215f-5 -4.9906765f-5 -0.00011683633 1.2859196f-5 -7.4363184f-6 1.859679f-5; 7.5515068f-6 -0.00028432516 -2.9372213f-5 -5.1813997f-5 -0.00018724542 1.3697345f-5 -1.4418723f-5 0.00013361816 -3.7624755f-5 0.00021185099 0.00013392954 3.6567166f-5 3.037784f-5 -0.00019799967 -0.00012865767 -8.1809616f-5 5.406697f-5 3.711895f-5 -2.4233306f-5 -0.00022279468 7.233523f-6 -3.225397f-5 -2.8647299f-5 -1.0115787f-5 -3.0498992f-5 0.00013277003 0.00020109364 -6.6292254f-5 -5.1289673f-5 -5.3001753f-5 -0.00012577773 9.9247314f-5; -2.3178036f-6 -7.687186f-6 -5.1439714f-5 0.00011979333 0.00015177466 0.00015731053 -7.6085744f-5 0.00015369501 -9.555848f-5 -3.607286f-5 -4.9253504f-5 -1.8041696f-5 6.991593f-5 3.5654168f-5 -4.7769445f-5 -7.237653f-5 -2.9420857f-5 1.2790264f-5 0.00012899125 -6.1023162f-5 0.00015690745 -4.945981f-5 -1.4179542f-5 -8.759225f-5 9.7526136f-5 -1.9565903f-6 -5.0793733f-5 -0.00016110588 4.566215f-5 2.8399101f-5 0.00013972739 0.00011016234; 2.9196686f-5 7.642594f-5 -8.349229f-5 0.00017957133 -2.7452259f-5 -1.6748971f-6 6.1725346f-5 -0.00010674283 -0.00013107542 -3.9523333f-5 -8.672673f-5 -1.5556727f-5 -9.710932f-5 4.1275245f-5 0.00017871514 8.17672f-5 
-0.00013292879 1.6399494f-5 4.0301347f-5 -2.7159198f-5 -0.000103962724 5.983994f-5 -8.955689f-5 -9.841677f-5 1.9365469f-5 2.3809042f-5 7.141543f-5 9.364005f-5 0.00018962678 -8.4845706f-5 -0.00016112339 -1.2022733f-5; 1.9647343f-6 -3.2669395f-5 -8.797973f-6 9.341806f-6 -0.00015375005 7.698852f-5 7.486651f-5 -0.00013987102 -0.00019667875 -2.238239f-5 4.559899f-5 -1.9144058f-5 -0.00013275119 -0.00016855092 -5.069719f-5 -1.0706748f-5 0.00018355672 -2.820033f-5 -7.433462f-5 7.797535f-8 -4.6601952f-5 -0.00015447536 0.0001907614 -1.04346855f-5 2.3944443f-5 8.4152685f-5 -3.7447113f-5 3.0666983f-5 -4.516903f-5 2.8111179f-5 -1.9734342f-5 0.000112924514; -8.252193f-5 -1.795953f-5 -0.00013818359 3.1200234f-5 -0.00026937507 -0.00014262713 -6.797775f-5 0.0001408827 -0.0002810724 6.112476f-5 0.000107600245 5.020993f-6 -0.000106695494 -0.00012427135 3.490826f-5 -0.00016379623 -8.1773826f-5 0.00020735181 -7.721253f-5 -4.8861253f-5 -6.895215f-6 0.00011330645 -4.1926498f-5 2.3337252f-5 9.280676f-5 5.7729292f-5 -0.00011107384 7.1339855f-5 -2.4411373f-5 -3.926188f-5 2.9714585f-5 -7.9509095f-7; 1.8110704f-5 -0.000120885496 -4.285573f-5 0.00016244363 9.314142f-5 8.586071f-5 -4.4722315f-6 -0.000114768234 -7.937131f-5 0.00015082624 -6.0473478f-5 0.0001632919 -0.00011608852 2.05564f-5 6.1343594f-6 1.7302602f-6 0.0001476379 -5.4723714f-5 -7.123665f-6 -7.6441254f-5 0.00016618548 -8.269644f-5 0.000100850986 0.00011631115 0.00013357063 -0.00013563948 -0.0001600738 -5.5440578f-5 2.1004005f-5 7.0929887f-6 -0.00015138503 2.019349f-5; 0.00028460275 -2.6439764f-5 9.4596f-5 0.00010042319 9.838127f-5 -4.9355574f-5 0.000121083816 -1.2991461f-5 -3.497792f-5 6.0534465f-5 -0.00015050844 0.00016724484 0.00011117113 4.3651555f-5 -7.2597184f-5 0.00010049173 -0.0001535547 -6.205443f-5 -6.4319924f-5 -3.4689037f-5 -2.1908574f-5 -7.151305f-5 -8.2381346f-5 -1.7662906f-5 -0.000104346436 6.9928254f-5 -0.00012548716 3.407672f-5 -0.00012329228 7.893362f-5 -0.00011192011 -1.2783545f-5; -4.2525553f-5 -2.6044196f-5 
-1.2113294f-6 8.0146885f-5 -7.3611576f-5 -4.4156245f-6 -8.535213f-5 -6.8296195f-6 -2.8941366f-5 -1.7703442f-5 -2.3878123f-5 1.2374563f-5 -5.4992597f-5 7.128881f-5 2.0617916f-5 -6.815324f-5 4.3576692f-5 0.00013350918 -2.9274635f-5 -0.00014633214 -0.0001319911 -0.000113457594 9.034094f-5 -4.096155f-5 4.646866f-5 5.6617206f-5 0.0001261131 0.0001911314 -3.1137202f-5 -0.00025507086 5.8884005f-5 -1.0261165f-5; 2.9590186f-5 -1.4459381f-5 3.677444f-5 0.00010160915 0.00022279954 -0.00012413115 5.3043645f-5 -0.00017414083 -6.534398f-5 -7.934843f-5 -4.0925534f-5 -3.394877f-5 -0.00010440473 -7.657355f-5 0.00012948994 0.00013067611 -6.71944f-5 8.572334f-5 -8.479514f-5 2.7201539f-5 3.7087952f-6 -0.00011199342 -6.5648856f-5 -2.128895f-5 -4.251491f-5 -0.00016193387 9.383741f-5 -1.851205f-5 3.125302f-5 1.3916076f-5 7.071896f-5 -9.763264f-5; 0.000168029 -0.00013815773 -0.0001228352 7.4446034f-5 6.374121f-5 -7.050717f-5 -0.00011515648 9.380805f-5 -2.0006017f-5 -4.1704003f-5 0.00014062777 -3.7231654f-5 -2.0779182f-6 5.084769f-5 9.076728f-5 -0.00017059194 -7.7709185f-5 0.00011052505 9.408469f-5 2.5597873f-7 -2.9635285f-5 -8.5000516f-5 0.00019938443 5.0181297f-5 0.0002381011 -0.00013768775 -1.118049f-5 -9.3277486f-5 8.3607134f-5 6.7127396f-5 -0.0001190456 4.976805f-5; 8.748496f-5 -1.4466068f-5 -0.0001488436 4.136578f-5 -8.373026f-5 0.00012277414 9.6217336f-5 -9.089332f-5 5.062788f-5 5.03895f-5 2.8930228f-5 -8.2116705f-5 4.7952108f-5 -6.5432156f-5 -5.9124057f-5 -0.0001341505 -9.801385f-7 0.00013293742 6.212999f-5 5.9503534f-5 -0.00010536116 -9.576269f-5 -6.155408f-5 7.0475566f-5 4.0892617f-5 8.2677965f-5 0.00018170461 0.000111374524 5.5483233f-5 -4.0171213f-5 -0.0001107968 3.7162186f-5; -9.160611f-5 6.785392f-5 0.00030410732 6.8954505f-5 -0.00016815854 0.00021792452 -5.325028f-7 0.00013354694 -0.00023282779 0.000106911604 -0.00012711986 1.7615257f-5 -9.178385f-5 -0.00012488321 -8.2487946f-5 -0.00013303995 -0.000107249085 6.0561815f-5 -8.799047f-5 5.5457658f-5 -0.0002811345 -4.0141622f-5 
-0.00010672232 8.0678656f-5 5.11301f-5 1.1686408f-6 7.206097f-5 1.1047553f-5 5.647971f-5 -0.00015219022 9.37925f-5 -8.296069f-5; 6.730224f-5 -6.422401f-5 8.4156156f-5 5.463516f-6 -6.268772f-5 -5.2629843f-5 9.578605f-5 -0.00011775905 7.7499804f-5 8.281399f-5 -0.000104484774 0.000116798394 -3.099093f-5 -0.00010582067 -0.00014998297 7.089235f-5 0.00016209303 0.0001327115 -2.434799f-5 -5.209492f-6 -3.329167f-5 0.0002467426 -8.194908f-5 0.00021488442 0.00027132954 -6.97252f-5 0.00010971671 8.806264f-5 -3.6547743f-5 -1.5638077f-5 1.2657566f-5 -5.1182567f-5; 0.00012042111 -1.9867357f-5 3.4209035f-5 -1.4937978f-5 -5.3598535f-5 7.2995103f-6 -0.0001445497 -5.040829f-5 -0.000105357336 7.7731194f-5 1.3394526f-5 -2.9901377f-5 9.223039f-6 -7.6247714f-5 7.585378f-8 0.0001210742 -0.00013181643 -3.3257442f-5 0.00019372163 -3.960856f-5 0.00011290148 -7.323326f-5 -6.662526f-5 9.490417f-5 -0.0002857896 0.00010588322 -3.9878167f-5 -6.543024f-5 6.961418f-5 -0.000106186286 3.9486516f-5 -7.173013f-5; -4.305229f-5 -0.00017745829 0.00012090801 7.133958f-5 3.1223888f-5 4.0892566f-5 -2.9516958f-5 0.00010821676 -1.3132327f-6 6.1162864f-5 -2.3300237f-5 7.1669856f-5 2.4932717f-5 2.1029437f-5 9.867136f-5 -1.0670836f-5 2.4621364f-5 -0.00015130143 -0.00010958983 8.193417f-6 2.3315391f-5 5.8365786f-5 -2.2881388f-6 4.4671666f-5 -0.00013702021 3.2824464f-5 -0.00024307592 -6.6675384f-5 -0.00023315162 0.000182628 -0.00011498523 -1.5660324f-5; 8.8485154f-5 -7.87515f-5 0.0001219041 -0.000112806185 -7.75084f-5 3.2478845f-5 -0.00010769345 -7.773728f-6 -0.0002767497 -9.60085f-5 8.779435f-6 -4.374673f-5 5.208385f-5 8.537889f-5 -0.00016875981 0.00012958239 0.00013256354 -4.1574567f-5 -4.2843414f-5 -4.168365f-5 -0.00014662794 -0.00019980507 -0.0002027488 4.0457013f-5 -2.0207413f-5 -9.392329f-5 -2.8620216f-6 4.5693887f-5 -6.6445464f-5 4.341217f-5 0.0001002438 1.3446011f-5; 7.9079255f-5 6.5052205f-5 -2.5872348f-5 7.911964f-6 -0.00012442697 -4.7086545f-5 0.00015745242 6.085674f-5 2.6768013f-5 0.00011875587 
8.210648f-5 4.319058f-5 -9.4335235f-5 -4.564757f-5 -1.769042f-5 -0.0001334264 3.4638945f-6 6.1022227f-5 0.00029720875 9.1079266f-5 2.7541959f-5 4.065741f-5 -0.000112822294 -0.00015146413 -0.00010087945 3.3730586f-5 -4.065618f-5 8.823788f-5 0.00013523837 2.7742057f-5 -0.0001128041 8.724441f-5; -8.297213f-5 -0.0001709167 5.8579088f-5 0.00012766622 -1.4375458f-5 9.6724805f-5 8.964048f-5 -8.1480306f-5 -0.00011447967 -4.02135f-5 2.9584833f-5 -6.6090906f-6 1.6590688f-5 0.00014489393 -5.1401563f-5 6.375396f-5 0.0003661303 -0.00011003656 4.9068145f-5 -0.00011411763 -6.019708f-5 -0.00010747153 6.851204f-5 -0.00011836779 -0.0001345824 0.00011131788 0.00014241395 0.00021312252 1.6290094f-5 0.00015748375 5.9412732f-6 -2.9661871f-5; -7.111297f-5 0.00016897649 4.5601442f-5 4.6413505f-5 0.00020944439 -7.2972364f-7 -6.2347455f-5 -6.0476773f-6 3.9570736f-5 6.893292f-5 9.637564f-5 -8.518468f-5 -6.6553046f-5 -1.7054572f-5 -1.5606087f-5 -0.00011180231 6.583345f-5 1.2563003f-5 0.00023637924 -0.00021256888 -0.00016168063 -7.641995f-5 0.00016093465 0.00014479135 2.7582323f-6 0.00010384723 -7.229928f-5 -3.852058f-5 -0.00024336421 -9.978207f-5 5.5344866f-5 -0.00013538974; -1.52073835f-5 9.956009f-5 -0.00019845463 -0.00011830686 -1.8564717f-5 -9.006222f-5 0.00015392146 7.550314f-5 2.0292051f-5 0.00012704598 -3.3462922f-5 0.00012584701 -4.505508f-5 -2.9447376f-6 5.1782677f-5 4.8512848f-5 0.00016178969 -3.6978865f-5 2.234746f-5 5.186636f-5 -3.7480797f-5 -9.6604424f-5 6.232042f-5 -6.6195906f-5 0.00012779115 -3.110201f-5 -0.00011296843 -3.235911f-5 5.327647f-5 2.2578095f-5 2.589611f-5 -1.0890285f-5; 0.00011879834 5.9195583f-5 -0.00010726192 4.7435802f-5 0.00019153542 -0.00021662521 -1.7992392f-5 -3.2689317f-5 -5.2614223f-6 1.7733724f-5 1.3076505f-5 -8.361196f-5 1.6357973f-5 4.7886137f-5 6.770588f-6 -8.203203f-5 -8.441587f-6 4.5295787f-5 -4.2741676f-6 -0.00018090726 7.8201854f-5 -0.00010685346 -5.2733467f-5 -6.3876364f-6 5.012855f-5 -0.00012225102 9.839198f-5 -2.2415636f-5 0.00019675789 
-8.043413f-5 -0.00011417455 -4.6947443f-5; 4.59079f-5 -3.652283f-5 -1.700653f-5 0.00017122744 -0.00014699806 0.00016689293 2.430295f-5 -1.5734955f-5 -4.539659f-5 -9.3341594f-5 -4.5096407f-5 -3.025975f-5 0.000119242155 -0.00019205357 -8.723956f-5 8.788374f-5 -8.823336f-5 -1.4689764f-5 -2.3874278f-5 8.178956f-5 5.156825f-5 -3.700158f-5 -6.0141992f-5 -5.0731505f-5 -1.7515906f-5 -0.00014513677 9.497544f-5 -1.001624f-6 9.146486f-5 -0.0001249632 -3.89889f-5 -0.00025021684; 0.0001354803 -2.21497f-5 3.5938887f-5 -1.8060859f-5 9.212001f-5 -0.00010094274 -2.2939673f-6 4.470639f-5 -0.00020590582 4.883026f-5 -3.5497436f-5 0.00012864712 -6.0074846f-5 7.320453f-5 -8.936635f-5 -5.018194f-5 -0.0002341994 -2.8768583f-5 -1.50020105f-5 0.00014375686 2.8918044f-5 -5.8867947f-5 7.565877f-5 -0.00013346903 6.49175f-5 -6.339532f-5 -1.3315627f-5 2.7028884f-5 -0.00016094398 -7.474623f-5 5.202415f-6 -2.7075419f-5; -2.5510604f-5 5.2090945f-5 7.4884534f-5 1.6195687f-5 9.85802f-5 4.7424397f-5 -3.788753f-5 1.5447218f-5 -0.00010373644 0.000101143254 3.4004996f-5 2.5306575f-5 6.7894252f-6 -4.0598592f-5 7.883938f-5 7.265421f-5 7.359494f-6 -1.100492f-6 -8.875312f-5 -4.6747966f-5 5.3878288f-5 4.0828138f-5 -0.00013919605 -2.0343576f-5 -1.4232318f-5 4.5741606f-5 6.602152f-5 -0.00014557518 -5.3306332f-5 -0.00022214504 0.00012549071 3.1803283f-5; -0.00012717294 -0.00012340643 -0.00018226792 2.31832f-5 -1.0850548f-5 8.6810476f-5 -1.9986248f-5 -0.00029390838 -8.316086f-5 -8.404781f-5 5.963333f-5 7.651245f-5 0.00023245024 -0.00014850235 5.1012343f-5 0.00011065739 -4.467614f-5 8.46267f-5 1.9409466f-5 -0.000111268666 -2.1146592f-5 -0.00011514195 -0.00017360986 0.00020077723 0.00017670341 6.5713255f-5 5.3843247f-5 -2.5984838f-7 2.8507931f-5 -3.255625f-5 0.00012269458 0.00015630561; 2.0229998f-6 0.00012384496 -0.00015452674 -0.000131889 -8.8320805f-5 -0.000106323045 -1.899296f-5 -8.2902916f-5 4.268664f-5 -0.00018840292 3.4501263f-5 -0.000108724984 -3.1409178f-5 -8.263111f-6 -0.00010497908 0.00011493257 
0.00016281095 -1.307123f-5 5.786231f-5 -2.4616236f-5 -9.0556845f-5 2.2604015f-6 -0.00015207827 2.689226f-6 0.0001791028 0.0001399229 0.00013914745 -6.636942f-6 -1.5621454f-5 0.00019278433 7.645258f-5 3.7531645f-5; -5.9790967f-5 8.343147f-5 -6.472752f-5 -4.28808f-5 4.879027f-5 8.209636f-5 -4.524478f-5 -6.999263f-5 -0.000140492 6.8170552f-6 1.2311136f-5 0.00015752351 -8.914805f-5 5.064584f-5 0.00017089977 -0.0001195835 7.331507f-5 -6.3800646f-5 8.714693f-5 0.00015066953 -3.079067f-5 6.970002f-5 -2.014589f-5 -0.00010271843 -9.425832f-5 3.822559f-5 8.1351f-5 -2.0711203f-5 -3.705162f-6 5.6103196f-5 4.3267428f-5 -1.5264943f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.00015358356 2.5440286f-5 -1.5099014f-6 -0.0001718956 -5.014758f-5 7.835021f-6 -4.246301f-5 6.6014574f-5 -2.0184785f-5 -7.831344f-5 -3.6320532f-6 0.00015751093 -2.3280481f-5 5.8419395f-5 0.00012825067 4.280869f-5 -0.00010744992 8.04928f-5 3.7062415f-5 -7.174852f-5 -6.132533f-5 -7.797381f-5 2.6975187f-5 -6.605608f-5 -0.00024206158 -1.9831778f-5 0.00010608923 0.00010708461 -0.000106148604 -5.3684134f-6 5.1158346f-5 5.293993f-6; 5.015503f-5 -3.176004f-5 -9.484684f-5 1.0218586f-5 -0.00022628634 -8.2961575f-5 2.4732726f-5 -9.160019f-5 -9.976108f-5 -4.910359f-5 4.5812416f-5 -0.00018301223 0.00012392302 -3.2352258f-5 -0.00010034683 -3.6988376f-6 0.00020076377 -0.00023366165 -1.19315555f-5 -0.00019355529 3.8368307f-5 4.463438f-5 3.9215014f-5 3.489989f-5 9.4386814f-5 -0.00019775279 6.0036265f-5 7.550404f-5 -2.8050024f-5 -1.3111858f-5 0.00010918907 0.00014535701], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we need Float64.

julia
# Convert the Float32 parameters from `Lux.setup` to Float64 and flatten them into a
# `ComponentArray`. This single vector is what the ODE problem and the optimizer receive.
const params = ComponentArray(ps |> f64)

# Callable wrapper around `nn` that holds the layer state `st`. `nothing` fills the
# parameter slot; parameters are passed at each call instead, e.g. `nn_model(x, nn_params)`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, using

$u[1] = \chi, \quad u[2] = \phi$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Neural-corrected orbit ODE. Returns `[χ̇, ϕ̇]` for the state `u = [χ, ϕ]`, using the
constants `(p, M, e)` from `ode_model_params` and the network `nn_model` with parameters
`nn_params`. The time `t` is not used.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # The network input is χ alone. `nn_model` is a `StatefulLuxLayer` that keeps the layer
    # state internally, so no state is threaded through here. For this network the state
    # is an empty NamedTuple, so ignoring its updates is safe. In general, `st` holds the
    # network's state and should not be ignored.
    correction = 1 .+ nn_model([χ], nn_params)

    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    χ̇ = rate * correction[1]
    ϕ̇ = rate * correction[2]
    return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Neural ODE problem built with the untrained Float64 parameters `params`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# Same fixed-step RK4 settings as the reference solve. `u0` and `p` are passed explicitly
# even though they equal the values already stored in `prob_nn`.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    # Overlay the reference waveform and the untrained neural ODE waveform.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # One legend entry per series, each combining its line and marker plots.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    # Last value of the block, so the figure is shown when this cell is evaluated.
    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

# Function barrier. `loss` reads several globals (`prob_nn`, `u0`, `tsteps`, `dt`,
# `dt_data`, `mass_ratio`, `waveform`) that are not `const`, so their types are unknown
# where they are read. Passing them in as arguments lets Julia, and Zygote when
# differentiating, compile this method for their concrete types.
function _waveform_loss(θ, prob, u0, saveat, dt, dt_data, mass_ratio, target)
    pred = Array(solve(prob, RK4(); u0, p=θ, saveat, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, target)
end

"""
    loss(θ)

Mean squared error between the plus-polarization waveform of the neural ODE
`prob_nn` solved with parameters `θ` and the reference `waveform`.
"""
function loss(θ)
    return _waveform_loss(θ, prob_nn, u0, tsteps, dt, dt_data, mass_ratio, waveform)
end
loss (generic function with 1 method)

Warm up the loss function.

julia
# Evaluate once before training so compilation happens outside the optimization loop.
loss(params)
0.0006897118474314336

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization callback. Appends the current loss `l` to `losses`, prints the iteration
number (`state.iter`) and the loss, and returns `false` so the optimizer never stops early.
"""
function callback(state, l)
    push!(losses, l)
    iteration = state.iter
    @printf(stdout, "Training \t Iteration: %5d \t Loss: %.10f\n", iteration, l)
    halt = false
    return halt
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results here, presumably because the Newtonian model already provides a very good initial guess.

julia
# Reverse-mode AD via Zygote; SciMLSensitivity (loaded at the top of the page)
# provides the rules for differentiating through `solve`.
adtype = Optimization.AutoZygote()
# `loss` needs only the parameters, so Optimization.jl's second argument is discarded.
optf = Optimization.OptimizationFunction((x, _) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, for at most 1000
# iterations; `callback` is invoked after each iteration.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback=callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-2.9276163331809014e-5; -0.00014470127644013674; 0.00014843475946702827; 3.4592641895872525e-5; -1.185872861241039e-5; -1.728520874166418e-5; 1.7360009223896223e-5; 9.485369082541121e-5; -4.376712968213737e-5; -6.96535862516385e-5; 4.8160836740840806e-5; -0.00019014961435449674; 4.016316233900718e-5; 6.106092769181081e-5; -2.983954436785467e-5; -0.00013815572310711915; -4.41373049397234e-5; -9.289837907999425e-5; -4.537424683797537e-5; -9.69430111580954e-5; -4.0396680560611356e-5; -8.097200770854079e-5; 5.0551912863718906e-5; 4.641294071914489e-5; 1.2114685887343882e-5; 0.00011992692452623346; -0.00013948247942611704; -0.00016010947001622992; -0.0001430859847459253; 1.8413569705429287e-5; -1.4092393030296336e-5; -0.00016102797235348676;;], bias = [-7.460802915947525e-18, -4.304587917125053e-17, -1.362933431879802e-16, -3.1708135878033885e-18, -1.557729256476062e-17, -5.155442610487028e-18, 2.9124980683413753e-17, -9.874262897525754e-18, -1.1534841312919236e-16, -2.6572285876769378e-17, 2.492220085581453e-17, -1.152322608776451e-16, 6.8105895727175745e-18, 6.697451088116373e-17, -5.768545198316634e-18, 8.120832048383503e-17, -8.272780660918924e-17, -4.94434616638093e-18, -7.117673086021156e-17, -5.942343835448956e-17, -5.783731212507261e-17, -9.591854041206341e-17, 2.2541271302125538e-17, 2.3759090343156087e-17, 1.2404365043398456e-17, -1.0960451928735808e-17, 4.310957389524347e-17, -1.0873075278899103e-16, -7.949012531383295e-17, 1.702671822662378e-17, 1.6266662650087774e-19, -1.5438944361363268e-16]), layer_3 = (weight = [-0.0001694666991170056 2.491069226384469e-5 0.0002281750510519919 0.00010093138716914451 -2.429151045359035e-5 -0.00013716638040929384 5.286637479006025e-5 -3.266206313124698e-5 4.583623467657302e-6 4.0781445750360764e-5 -7.801656302032245e-5 -3.718626085168992e-5 -7.878810556573082e-5 -9.81645991486974e-5 0.00010722725874078204 -6.168016097940758e-5 -1.265866849277138e-5 
2.167899019699684e-5 -1.938719105932036e-6 3.1088658761328024e-5 -0.00017124405549108268 1.8758870101045226e-5 -3.366869914313017e-5 -2.3669350593958107e-5 4.015436734337996e-5 -0.00011868908604806784 -2.5241417642931948e-5 0.00013228027255540796 1.7981996097702564e-5 1.2743939623792598e-5 -0.00011468752036304506 -3.964550542145872e-5; 0.00015125284039897482 1.660196426458189e-5 0.00010304078819247471 5.350084097955956e-5 0.0001632331411681844 -0.00013316159242747735 -5.754787023015482e-6 -3.905800190039008e-5 1.1454984590109942e-5 -8.01729831866195e-5 1.5071824011848515e-5 -1.1804880384410055e-5 -0.00013268114639429707 -0.00011870822276903564 -5.968214759121083e-5 0.0002598825528846906 -7.51117179308305e-5 0.00010643275959666774 7.298184277971891e-5 -0.0001502923107611203 -0.00022340712825751264 -0.0001772521784434294 7.671251071607067e-6 1.3130829532038557e-5 2.035196189610358e-5 -7.135493001185165e-5 0.0001480011712854582 -5.999515928776278e-5 0.00012576583026459887 -0.00016376633652469527 -2.6807681249174493e-5 -7.191568078921438e-5; -9.977306091862746e-6 -0.00012297527069850072 5.521266570039751e-5 -7.58426818891506e-6 -4.0774284627225576e-5 1.9046128671213533e-5 0.00017734036004014083 0.0001386613257073884 0.00015572468322544916 -6.629257774137593e-5 0.0002164972501651529 1.8914392001609093e-5 -8.113665904773746e-5 -9.472259136497246e-6 -5.488990797119499e-5 -0.00013671108541702867 8.986980896053459e-5 -2.055566261505273e-5 0.00010318188645952926 4.544017059571012e-5 0.00010455463410664001 1.5506955549527535e-5 -5.8413141288733674e-5 6.269243380130806e-5 -5.459811660501394e-5 -6.868825954543028e-5 -0.00020949841206120375 0.00023652035082531503 9.836470775380593e-5 -4.386266306199369e-5 8.82403364249669e-5 6.670524811678883e-5; 4.350335474972974e-5 -8.730691443397213e-5 -6.911406635972165e-5 -6.073635413977871e-5 -5.34655660934807e-5 -2.0551580387662565e-5 -3.110612720236698e-6 4.019064762778424e-5 -1.7303145437296992e-5 -8.630120246871742e-5 
-2.8364181712647277e-5 -2.393047143904199e-5 7.450280135762345e-5 -4.378986527996193e-6 6.437580340806758e-5 3.208083616719006e-5 -1.6657104246471712e-6 -7.463357870201756e-5 1.3144329745585127e-5 -4.653747201282668e-5 -0.00013099440512215316 8.296096384091953e-5 5.765248387851268e-5 -0.00010770502522555686 -9.409383310904401e-5 0.00019265124901230566 -0.00013838260165848393 -0.00024856664786547663 -6.325110699017106e-5 0.00022413615115018896 0.0001161617841667901 -2.331495902939292e-6; 0.00019415488173542386 -2.1623480226882805e-5 -5.3892357702755245e-5 -5.873474440184136e-6 5.645426977714413e-5 -0.00013471495362175749 0.0001744955353002052 2.481103370661612e-5 -0.00014051716602118914 0.00014285802722731182 0.00013257816140828399 -1.6328682197539107e-5 4.775448867414621e-6 -0.00011603821071042097 5.03428983007076e-5 5.1043009132226224e-5 0.00012121782315950608 0.00015584424368859524 2.085494275667932e-5 -7.72362199699579e-5 -9.195527984436839e-6 -9.587347625932264e-5 -1.908047847103133e-5 0.00016467099890867006 0.00012230151156252138 -2.3730484774602935e-5 -2.0390165393510847e-5 -4.9904714562622304e-5 -0.00011683428050895756 1.286124648465651e-5 -7.434268302337711e-6 1.8598839285509056e-5; 7.54937892186547e-6 -0.00028432728467387883 -2.9374341095631934e-5 -5.18161246808828e-5 -0.00018724754372026147 1.3695217064266855e-5 -1.442085121463815e-5 0.00013361603160509216 -3.7626883197773225e-5 0.00021184885955070166 0.00013392741348714415 3.6565037960012934e-5 3.037571187367457e-5 -0.00019800180197573328 -0.00012865980133294384 -8.18117437965229e-5 5.4064843795530534e-5 3.711682111960186e-5 -2.4235434096154636e-5 -0.0002227968103333436 7.231395017769241e-6 -3.225609870507244e-5 -2.8649426524607933e-5 -1.0117914455732613e-5 -3.050111953864181e-5 0.00013276790232984094 0.00020109151628920906 -6.629438175232551e-5 -5.1291800985297336e-5 -5.3003881105653154e-5 -0.00012577986089392867 9.924518636316613e-5; -2.31594544990279e-6 -7.685327501782463e-6 -5.143785615607815e-5 
0.00011979518881029139 0.000151776514984283 0.00015731238367928867 -7.608388564354998e-5 0.00015369687302930424 -9.555661831788144e-5 -3.607100093310302e-5 -4.9251645533753635e-5 -1.803983736998515e-5 6.991778632121753e-5 3.565602587380912e-5 -4.7767586649081796e-5 -7.237467521142683e-5 -2.941899852109033e-5 1.2792122150614663e-5 0.0001289931051801761 -6.102130369205839e-5 0.00015690931017937735 -4.945795167392758e-5 -1.4177683580295865e-5 -8.759039060484058e-5 9.752799364699734e-5 -1.9547321257706247e-6 -5.0791874811216525e-5 -0.00016110402516595026 4.566400734668042e-5 2.8400959352927485e-5 0.00013972924630119215 0.00011016420016723729; 2.9196161759836242e-5 7.642541399161264e-5 -8.349281475477899e-5 0.0001795708079420083 -2.7452783486870604e-5 -1.6754216027322012e-6 6.172482184535405e-5 -0.00010674335558963278 -0.00013107594046107883 -3.952385728428481e-5 -8.672725440067624e-5 -1.555725126183864e-5 -9.71098421893029e-5 4.127472069864733e-5 0.00017871461690571684 8.176667264831155e-5 -0.0001329293160403093 1.6398969220399892e-5 4.030082285805218e-5 -2.7159722466084945e-5 -0.00010396324856504262 5.98394138627282e-5 -8.955741273769558e-5 -9.84172953927835e-5 1.936494422022298e-5 2.3808517054359512e-5 7.141490581613874e-5 9.363952207547625e-5 0.0001896262541243892 -8.484623018251219e-5 -0.0001611239137297445 -1.2023257712486474e-5; 1.9628787283418902e-6 -3.267125090520545e-5 -8.799828960272704e-6 9.339950098116155e-6 -0.00015375190584043086 7.69866671687175e-5 7.486465595841375e-5 -0.00013987287117221835 -0.00019668060873689067 -2.238424555277676e-5 4.5597134645256164e-5 -1.9145913269540352e-5 -0.00013275304105652452 -0.00016855277523452245 -5.0699043880989754e-5 -1.0708603680779368e-5 0.00018355486803168497 -2.82021849345881e-5 -7.433647284803838e-5 7.611974107863905e-8 -4.660380751763424e-5 -0.00015447721695115445 0.00019075955036508023 -1.0436541073680072e-5 2.3942587846000243e-5 8.41508296941675e-5 -3.7448968454793015e-5 3.066512702911729e-5 
-4.517088406306435e-5 2.8109323225840612e-5 -1.9736197158885594e-5 0.00011292265820497302; -8.252511170202944e-5 -1.7962709443847075e-5 -0.00013818676952711973 3.119705443178261e-5 -0.0002693782515961133 -0.00014263031325744603 -6.798093081660704e-5 0.0001408795143107588 -0.0002810755758217905 6.112158016559073e-5 0.00010759706501640453 5.017813328000382e-6 -0.00010669867324385792 -0.0001242745273883805 3.490507888282573e-5 -0.0001637994063456733 -8.17770050160956e-5 0.00020734863219341992 -7.721570718776387e-5 -4.886443266503634e-5 -6.898394440116806e-6 0.00011330327115512038 -4.192967744519835e-5 2.3334072557287853e-5 9.280358071928692e-5 5.7726112663768495e-5 -0.00011107701711796229 7.133667540996043e-5 -2.441455208088169e-5 -3.926505805784037e-5 2.9711405750371904e-5 -7.982704572586104e-7; 1.811127867954862e-5 -0.00012088492121433672 -4.28551543318e-5 0.00016244420067055686 9.314199223995767e-5 8.586128342545214e-5 -4.471657119160789e-6 -0.00011476765985020976 -7.937073723764564e-5 0.00015082681011522657 -6.0472903440263545e-5 0.0001632924754649575 -0.00011608794784292047 2.0556973597244608e-5 6.1349337575441135e-6 1.7308345850792596e-6 0.0001476384709367767 -5.4723139781296e-5 -7.1230906523351066e-6 -7.644068000258493e-5 0.00016618605104854855 -8.269586441924358e-5 0.00010085156054815791 0.00011631172628954574 0.00013357120724509096 -0.00013563890433756343 -0.00016007322892432838 -5.544000350523662e-5 2.1004579595729596e-5 7.093563001085727e-6 -0.00015138445496444967 2.0194065202349593e-5; 0.00028460287634618535 -2.6439636857389834e-5 9.459612541934341e-5 0.00010042331618126433 9.838139871237755e-5 -4.935544726220004e-5 0.00012108394352355926 -1.2991333422973621e-5 -3.497779121869295e-5 6.05345920495217e-5 -0.00015050830930601859 0.00016724496876224687 0.00011117125890552947 4.365168191827776e-5 -7.259705685908024e-5 0.0001004918557019892 -0.00015355457813609442 -6.205430341549298e-5 -6.431979650994816e-5 -3.4688910235558514e-5 -2.190844694166494e-5 
-7.151292462265191e-5 -8.23812189618131e-5 -1.7662778335647244e-5 -0.00010434630908900917 6.992838083635912e-5 -0.0001254870332185191 3.407684547812529e-5 -0.00012329215329270936 7.893374626394315e-5 -0.00011191998433684975 -1.2783417477180814e-5; -4.252652764255623e-5 -2.6045170336074983e-5 -1.212304034125376e-6 8.014591036778444e-5 -7.361255087010096e-5 -4.416599160256977e-6 -8.535310155996581e-5 -6.830594094867772e-6 -2.894234024272709e-5 -1.770441634282784e-5 -2.38790977454551e-5 1.2373588440900105e-5 -5.499357178626069e-5 7.12878321501857e-5 2.0616941008549423e-5 -6.815421477942588e-5 4.357571776206197e-5 0.00013350820551574972 -2.92756100242765e-5 -0.000146333113327205 -0.00013199206990241488 -0.00011345856876862495 9.03399635427127e-5 -4.096252417359902e-5 4.6467685359013276e-5 5.661623105704775e-5 0.00012611212184135954 0.0001911304283276701 -3.113817695695721e-5 -0.0002550718342098919 5.88830301858904e-5 -1.0262139432524447e-5; 2.9588829554601835e-5 -1.4460738250293939e-5 3.677308281801486e-5 0.00010160779142785671 0.00022279818599378498 -0.0001241325115665253 5.3042287752365716e-5 -0.00017414219096502088 -6.534533518100268e-5 -7.934978553060427e-5 -4.0926890371731504e-5 -3.3950127410585946e-5 -0.00010440608807806051 -7.657490991953794e-5 0.00012948858434134935 0.00013067475460736726 -6.719575672144181e-5 8.572198216604255e-5 -8.479649464100705e-5 2.720018178258309e-5 3.707438415919802e-6 -0.000111994773626457 -6.565021235695194e-5 -2.1290306769908502e-5 -4.2516265104889915e-5 -0.00016193522956421162 9.38360500618509e-5 -1.8513406623625152e-5 3.125166417483017e-5 1.3914718827042294e-5 7.071760423747633e-5 -9.763399780460617e-5; 0.000168030182041757 -0.00013815654986299578 -0.00012283402341179602 7.44472153857984e-5 6.374238998635532e-5 -7.05059910867657e-5 -0.00011515529788791348 9.380922981311088e-5 -2.0004835303344055e-5 -4.1702821388545865e-5 0.0001406289547705721 -3.723047214205966e-5 -2.0767368166802844e-6 5.0848872863051496e-5 9.076846160699443e-5 
-0.00017059076074272327 -7.770800314371685e-5 0.00011052622992085759 9.408586899756239e-5 2.571601416786806e-7 -2.963410318240721e-5 -8.49993348563408e-5 0.00019938561578986002 5.0182478819105454e-5 0.00023810228137538625 -0.00013768656665686275 -1.117930881062239e-5 -9.327630416717565e-5 8.360831541014846e-5 6.712857700425053e-5 -0.0001190444209239739 4.9769229892462384e-5; 8.748597365730395e-5 -1.4465051841850087e-5 -0.00014884257934363377 4.1366796088832764e-5 -8.372924157491907e-5 0.00012277516116547364 9.621835243116104e-5 -9.08923054307638e-5 5.062889731881125e-5 5.039051511949679e-5 2.893124410867128e-5 -8.211568883462165e-5 4.795312388842852e-5 -6.543113949782359e-5 -5.9123040679414315e-5 -0.00013414948143372036 -9.791222705552074e-7 0.00013293843703969074 6.213100760170169e-5 5.950455061202538e-5 -0.0001053601469688963 -9.576167458043671e-5 -6.155306136542106e-5 7.047658191687335e-5 4.0893633289225154e-5 8.267898129225212e-5 0.00018170563026024294 0.00011137554002152816 5.54842493361351e-5 -4.0170197004489914e-5 -0.0001107957820718086 3.716320252670027e-5; -9.160798126471593e-5 6.785205132689274e-5 0.0003041054509616272 6.895263450347476e-5 -0.0001681604122474178 0.0002179226497922705 -5.343732460837617e-7 0.0001335450719146724 -0.0002328296619431425 0.0001069097337698273 -0.00012712172665273113 1.7613386628347965e-5 -9.178571835731474e-5 -0.00012488508273021505 -8.248981666623976e-5 -0.00013304182411741632 -0.00010725095544115937 6.055994472477701e-5 -8.799233964749902e-5 5.5455787698851685e-5 -0.00028113637435821205 -4.0143492331151025e-5 -0.00010672419066180734 8.067678518487788e-5 5.1128230008870005e-5 1.1667703141648272e-6 7.205909733090405e-5 1.1045682551753768e-5 5.647783782694409e-5 -0.00015219208635029776 9.379062921867301e-5 -8.29625574602906e-5; 6.730547791594168e-5 -6.422077096287873e-5 8.415939699404975e-5 5.466757176325948e-6 -6.268448161645258e-5 -5.262660162625315e-5 9.578928764458332e-5 -0.00011775581189650502 7.750304537839606e-5 
8.281723019680983e-5 -0.00010448153299730896 0.00011680163477791027 -3.0987689040923406e-5 -0.00010581742791910497 -0.000149979725773302 7.089558819787057e-5 0.00016209626808776865 0.00013271474343272505 -2.4344748834121416e-5 -5.206250915226683e-6 -3.328843051204078e-5 0.0002467458358129007 -8.194584044613593e-5 0.00021488766249749204 0.00027133278088658656 -6.972196244388845e-5 0.00010971995250800067 8.806588046497391e-5 -3.654450157994137e-5 -1.5634836398385904e-5 1.2660807061910804e-5 -5.117932634532757e-5; 0.00012041955192332043 -1.9868912612512026e-5 3.4207479576506036e-5 -1.4939534092930639e-5 -5.3600090221820034e-5 7.297954699666918e-6 -0.00014455125933521566 -5.040984474247021e-5 -0.00010535889164087422 7.772963867980146e-5 1.3392970278134644e-5 -2.990293291869756e-5 9.221483697971976e-6 -7.62492695218787e-5 7.429813253257927e-8 0.00012107264187877879 -0.00013181798426442837 -3.3258997472084096e-5 0.00019372007368589697 -3.9610115389885894e-5 0.00011289992248863839 -7.323481845444824e-5 -6.662681557712172e-5 9.490261577526992e-5 -0.00028579115901150885 0.00010588166474124058 -3.987972236129556e-5 -6.543179774664128e-5 6.961262402117207e-5 -0.00010618784149185893 3.9484960826727016e-5 -7.173168564688186e-5; -4.305355254708908e-5 -0.00017745955632797667 0.00012090674350240088 7.133831455340502e-5 3.122262419984527e-5 4.0891302303325824e-5 -2.9518221383650352e-5 0.00010821549436278859 -1.3144965738219894e-6 6.11615999854733e-5 -2.3301501316917716e-5 7.166859209740007e-5 2.4931452896826947e-5 2.1028172666239545e-5 9.867009280916185e-5 -1.0672099636013819e-5 2.4620100118600952e-5 -0.0001513026924318055 -0.00010959109717313269 8.192152888490817e-6 2.3314127572494615e-5 5.836452178974752e-5 -2.289402702980352e-6 4.46704019463242e-5 -0.00013702147193052022 3.2823200269755614e-5 -0.00024307718344296613 -6.667664787762828e-5 -0.00023315287912062613 0.00018262674198686935 -0.00011498649623043981 -1.566158756993729e-5; 8.848167166033568e-5 -7.875498446481063e-5 
0.00012190061791954928 -0.00011280966752056187 -7.751187983049356e-5 3.247536263697524e-5 -0.00010769693214060997 -7.77721030716978e-6 -0.0002767531742978475 -9.601198431882745e-5 8.775952515497078e-6 -4.3750214139762074e-5 5.208036858275426e-5 8.537540528377654e-5 -0.0001687632949728165 0.00012957890688596326 0.00013256005589537255 -4.1578049753655244e-5 -4.2846896726042585e-5 -4.1687130910206214e-5 -0.00014663141806922958 -0.00019980855464782107 -0.0002027522761312872 4.0453530196463245e-5 -2.0210896004212975e-5 -9.392677490212532e-5 -2.865504179116559e-6 4.569040432727743e-5 -6.644894634917192e-5 4.340868584731341e-5 0.00010024031765678628 1.3442527993564695e-5; 7.908121369041828e-5 6.505416421645285e-5 -2.5870389443990915e-5 7.913922664786926e-6 -0.00012442500702377038 -4.7084586193474124e-5 0.00015745438030649337 6.0858697898991275e-5 2.67699724565798e-5 0.00011875782546781071 8.21084413394169e-5 4.319253820652057e-5 -9.433327614601425e-5 -4.5645609142182856e-5 -1.768846051049562e-5 -0.00013342443526995288 3.4658535162204506e-6 6.1024185916947314e-5 0.00029721071222525334 9.108122502879946e-5 2.7543917886990075e-5 4.065936990741599e-5 -0.00011282033484235189 -0.00015146217447977272 -0.00010087749070293833 3.37325454813653e-5 -4.06542221833294e-5 8.823983988917094e-5 0.0001352403287375997 2.774401581632797e-5 -0.00011280213767235731 8.724636608461492e-5; -8.297015882975371e-5 -0.0001709147239960203 5.858106088294263e-5 0.00012766819564445902 -1.4373484547591263e-5 9.67267784365182e-5 8.964245499397714e-5 -8.147833286033486e-5 -0.00011447769647178259 -4.021152780454812e-5 2.9586806347891254e-5 -6.607117385793406e-6 1.6592660911795706e-5 0.00014489590249310114 -5.139958956691387e-5 6.375593657890142e-5 0.00036613226403302496 -0.00011003458929894013 4.907011826459292e-5 -0.00011411565937282086 -6.019510546296834e-5 -0.0001074695595814226 6.851401268091924e-5 -0.00011836581542569247 -0.00013458042521209362 0.00011131985375709942 0.00014241592144412698 
0.00021312449244601035 1.6292067455879366e-5 0.00015748572478356854 5.943246477058624e-6 -2.965989769268486e-5; -7.111303839556182e-5 0.00016897642071635267 4.560137258304093e-5 4.641343493640185e-5 0.00020944432198812012 -7.29793307438154e-7 -6.234752491072993e-5 -6.047746958212098e-6 3.957066601069408e-5 6.893284753241347e-5 9.63755695898714e-5 -8.518474769734934e-5 -6.655311572321928e-5 -1.7054641468713515e-5 -1.5606156559807767e-5 -0.00011180237616135934 6.583337690026262e-5 1.2562933441221371e-5 0.00023637917492815057 -0.00021256894670150463 -0.00016168069542109274 -7.642002001594045e-5 0.00016093458125178466 0.00014479128312595884 2.758162643293545e-6 0.0001038471631523853 -7.229934690476589e-5 -3.852064819730068e-5 -0.00024336427723246822 -9.978213766080569e-5 5.534479659909142e-5 -0.00013538980838566272; -1.52063646596112e-5 9.956111192927706e-5 -0.0001984536063532254 -0.00011830584095100192 -1.8563697958092494e-5 -9.006119922509637e-5 0.00015392247609320806 7.550416216013439e-5 2.029306982927215e-5 0.00012704699496730786 -3.3461903294160116e-5 0.0001258480335678017 -4.5054060828483525e-5 -2.943718763568195e-6 5.1783696326815905e-5 4.85138664230805e-5 0.0001617907112091007 -3.6977846102170086e-5 2.2348478742565024e-5 5.1867377115333756e-5 -3.7479778038182165e-5 -9.660340561772198e-5 6.232143663560676e-5 -6.619488682139947e-5 0.00012779216944232367 -3.110098965357754e-5 -0.0001129674126418208 -3.235808959239308e-5 5.327749044276803e-5 2.2579113866009393e-5 2.589712879903863e-5 -1.088926615871412e-5; 0.00011879718441990195 5.919442411754096e-5 -0.00010726307881055542 4.743464307169433e-5 0.00019153426330079288 -0.00021662636789484912 -1.7993550848830795e-5 -3.269047591244504e-5 -5.262581347112703e-6 1.773256519002981e-5 1.3075345683985385e-5 -8.361311700651032e-5 1.635681439141297e-5 4.788497755428875e-5 6.769428743366072e-6 -8.203318646634251e-5 -8.44274598718356e-6 4.529462750427709e-5 -4.275326667327538e-6 -0.0001809084154804299 7.820069514921886e-5 
-0.00010685462110201623 -5.2734625918007833e-5 -6.388795448644855e-6 5.012739131119067e-5 -0.00012225217722813098 9.83908226923993e-5 -2.241679465738132e-5 0.00019675672739753297 -8.04352915868259e-5 -0.00011417571148128282 -4.6948602524260016e-5; 4.590549269229136e-5 -3.6525235378596456e-5 -1.7008936089060792e-5 0.00017122503322203634 -0.0001470004690473071 0.0001668905269925381 2.4300544277102796e-5 -1.573736063258587e-5 -4.539899692284829e-5 -9.334399965704143e-5 -4.5098812739559607e-5 -3.0262154728924534e-5 0.00011923974943234604 -0.00019205597801013664 -8.724196687220631e-5 8.788133254163428e-5 -8.823576809485395e-5 -1.4692169321308713e-5 -2.387668337727375e-5 8.178715779101176e-5 5.156584574611881e-5 -3.70039861136634e-5 -6.014439805406982e-5 -5.0733910944507737e-5 -1.7518311329731727e-5 -0.00014513917722618115 9.497303380519642e-5 -1.0040296329116926e-6 9.146245701592987e-5 -0.00012496560976254836 -3.8991304796618364e-5 -0.0002502192410302248; 0.00013547841711675137 -2.21515775296496e-5 3.593700869948356e-5 -1.8062736733049503e-5 9.211813351844709e-5 -0.0001009446214697575 -2.295845495105107e-6 4.4704512158714945e-5 -0.0002059076951513308 4.882838305325809e-5 -3.5499314496922896e-5 0.00012864524517915832 -6.007672442814903e-5 7.320265188587392e-5 -8.936823093558343e-5 -5.018381950088274e-5 -0.00023420127496814902 -2.8770461601597705e-5 -1.5003888681752893e-5 0.0001437549789739577 2.8916166168124327e-5 -5.8869824958893334e-5 7.565689059679818e-5 -0.00013347090959309204 6.491561992566372e-5 -6.339719499686519e-5 -1.3317505424704266e-5 2.7027005610372607e-5 -0.00016094585686681205 -7.474811088054159e-5 5.2005370276262675e-6 -2.707729715205272e-5; -2.551039863381754e-5 5.209115085828229e-5 7.488473972709021e-5 1.6195892451757213e-5 9.858040909686377e-5 4.742460266508413e-5 -3.788732458550232e-5 1.5447423786932653e-5 -0.00010373623620315214 0.00010114345975390856 3.400520115330121e-5 2.530678109261681e-5 6.789630854579241e-6 -4.059838641031418e-5 
7.883958648828106e-5 7.265441861174192e-5 7.359699405166386e-6 -1.100286387631008e-6 -8.875291489592514e-5 -4.6747760213369474e-5 5.387849348418657e-5 4.082834318334445e-5 -0.00013919584846651878 -2.0343370058147653e-5 -1.4232112268549204e-5 4.5741811912117224e-5 6.602172841022843e-5 -0.00014557497154517947 -5.3306126507854804e-5 -0.00022214483572232285 0.0001254909166827171 3.18034891032901e-5; -0.00012717302896881464 -0.0001234065130945102 -0.00018226800611110332 2.3183112634155107e-5 -1.0850635672528187e-5 8.681038881085824e-5 -1.998633532885451e-5 -0.0002939084662018154 -8.316095053551703e-5 -8.404789704464353e-5 5.963324284429679e-5 7.651236220162475e-5 0.00023245015464389712 -0.00014850244050713233 5.1012255343534184e-5 0.00011065730351155688 -4.467622792681663e-5 8.462661199636792e-5 1.9409379009897987e-5 -0.00011126875332657211 -2.1146679583312795e-5 -0.00011514203660290661 -0.00017360995124427584 0.0002007771398978496 0.00017670332445895312 6.571316763192467e-5 5.384315946620956e-5 -2.5993577146891654e-7 2.850784395313426e-5 -3.25563379107787e-5 0.0001226944877173939 0.00015630552653355607; 2.0229286856122815e-6 0.00012384489115174145 -0.00015452681522038544 -0.00013188907008771612 -8.832087600058703e-5 -0.00010632311581321819 -1.8993030733334025e-5 -8.290298695829503e-5 4.268656799649065e-5 -0.00018840298993200814 3.450119207808926e-5 -0.00010872505494081158 -3.1409248675124814e-5 -8.263182406160558e-6 -0.00010497915185442952 0.00011493249791055203 0.00016281087470087398 -1.3071300725648577e-5 5.7862240157185035e-5 -2.4616307326148176e-5 -9.055691601862023e-5 2.260330449637395e-6 -0.00015207833906789132 2.689154927786474e-6 0.00017910273332926615 0.0001399228220743613 0.0001391473796156722 -6.6370131553476695e-6 -1.5621524778914292e-5 0.00019278425485207872 7.645251195025803e-5 3.753157385471511e-5; -5.9790109221042585e-5 8.343232883404235e-5 -6.472666456333166e-5 -4.287994335257596e-5 4.8791127197061424e-5 8.209721971567599e-5 -4.5243920171317926e-5 
-6.999177304763714e-5 -0.00014049114571619946 6.8179135136828405e-6 1.2311994110374323e-5 0.00015752436564567298 -8.914719393445337e-5 5.064669647751157e-5 0.00017090063164275 -0.00011958264104990943 7.331592467877846e-5 -6.379978759881465e-5 8.714778931377285e-5 0.00015067038446928927 -3.078981200318012e-5 6.970087969008537e-5 -2.0145031443947914e-5 -0.00010271757520470941 -9.425746276477468e-5 3.82264476551995e-5 8.135185606575839e-5 -2.0710344246746028e-5 -3.70430376691678e-6 5.610405395188246e-5 4.3268286104163825e-5 -1.5264084426727218e-5], bias = [-1.5554178070723933e-9, -1.2653319391487635e-9, 2.6884100565617325e-9, -1.542144135194059e-9, 2.0500534203967124e-9, -2.1278473673452427e-9, 1.8581418203550172e-9, -5.244679698200622e-10, -1.8556095478684513e-9, -3.1795024329828967e-9, 5.7434777358519e-10, 1.2718476311827337e-10, -9.74643561101821e-10, -1.3568318841284446e-9, 1.181408361108909e-9, 1.0161816233643573e-9, -1.870464658811178e-9, 3.241058584218752e-9, -1.5556440611216806e-9, -1.2638725728337795e-9, -3.4826047375134603e-9, 1.9590436332696213e-9, 1.973247221801222e-9, -6.966947818256928e-11, 1.018852850180148e-9, -1.159044830543445e-9, -2.405618979718672e-9, -1.878167900408859e-9, 2.0561211994944987e-10, -8.739398781857189e-11, -7.106610400783757e-11, 8.582769322975083e-10]), layer_4 = (weight = [-0.0008350591836811983 -0.000656035359953281 -0.0006829854143568474 -0.0008533712274555558 -0.0007316231649128854 -0.0006736405555528378 -0.0007239386083227933 -0.0006154611028169809 -0.0007016603868907734 -0.0007597888750886944 -0.0006851077287009645 -0.0005239647485397156 -0.0007047561415669539 -0.0006230562456504327 -0.0005532249810320883 -0.000638666969541028 -0.0007889255174090253 -0.0006009826485222614 -0.0006444132117336342 -0.0007532241689014908 -0.0007428007146292102 -0.0007594493997471415 -0.0006545004036955421 -0.0007475317668878131 -0.0009235372342822737 -0.0007013074294870237 -0.0005753863175510075 -0.0005743909896326071 -0.0007876242860287626 
-0.0006868440966679428 -0.0006303173375492071 -0.0006761816723946959; 0.000300820036372757 0.00021890497504869938 0.00015581812837027557 0.000260883591407148 2.437864959495959e-5 0.00016770341398229834 0.0002753977230974601 0.00015906483721137433 0.00015090392091282747 0.0002015613474912601 0.0002964774407165407 6.765279294376541e-5 0.00037458803787646143 0.00021831275387102146 0.00015031818837002884 0.0002469661811689891 0.00045142876177339683 1.700329577516754e-5 0.00023873345135893724 5.7109727343027277e-5 0.0002890332259698558 0.0002952993731401426 0.0002898800079143444 0.00028556491869389736 0.0003450518321000143 5.291222984035246e-5 0.0003107012440815659 0.0003261690359640274 0.0002226150029090818 0.00023755316993940252 0.00035985410121475544 0.00039602203599170294], bias = [-0.0006814756834102439, 0.00025066502770358564]))

Visualizing the Results

Let us now plot the loss over the course of training.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # `losses` holds one value per optimizer iteration, as pushed by the callback.
    # Draw it as a thick line with a marker at every recorded point.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results by comparing the data with the untrained and trained neural network waveforms.

julia
# Re-simulate with the optimized parameters `res.u` on the same fixed-step RK4 grid
# and convert the trajectory to a waveform (first output of `compute_waveform`).
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Every series is a translucent line plus circular markers.
    # Series 1: the reference data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Series 2: the prediction before training (computed with the initial parameters).
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Series 3: the prediction after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Grouping each line with its scatter yields one combined legend entry per series.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
# Print the Julia version, platform and environment used to build this page.
InteractiveUtils.versioninfo()

# GPU toolkit details are printed only when MLDataDevices is loaded. Each backend is
# reported only if its package is loaded and MLDataDevices says it is functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()  # blank line separating this from the preceding report
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()  # blank line separating this from the preceding report
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.