Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model, which, from a scientific machine learning perspective, are not especially relevant.

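We need a very crude two-body path. Assume the one-body motion is a Newtonian two-body relative position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g., *Theoretical Mechanics of Particles and Continua*, §4.3). With $M = m_1 + m_2$, this gives $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$.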

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the two bodies' trajectories about
the centre of mass. With `M = m₁ + m₂`, returns `(m₂/M) .* path` for body 1 and
`(-m₁/M) .* path` for body 2, so that their difference recovers `path`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    return (m₂ / total) .* path, (-m₁ / total) .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$ and $y = r\sin\phi$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution in orbital variables into Cartesian coordinates.
#
# `soln` has one column per time sample and either
#   * 2 rows `[χ; ϕ]` — then `model_params = (p, M, e)` (length 3) must be given and
#     the constant `p`, `e` are used, or
#   * 4 rows `[χ; ϕ; p; e]` — then `p`, `e` are taken per sample from the solution.
#
# Uses r = p / (1 + e cos χ), x = r cos ϕ, y = r sin ϕ and returns a 2×N matrix
# whose rows are x and y.
#
# Throws `ArgumentError` if `soln` does not have 2 or 4 rows, or if `model_params`
# is missing / not of length 3 for a 2-row `soln`.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    nrows ∈ (2, 4) || throw(ArgumentError("size(soln,1) must be either 2 or 4"))

    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        # Check for `nothing` first so a missing argument gives a clear error
        # instead of a MethodError from `length(nothing)`.
        (model_params !== nothing && length(model_params) == 3) || throw(
            ArgumentError("model_params must have length 3 when size(soln,1) = 2"))
        p, _, e = model_params  # the mass M does not enter the orbit shape
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

First time derivative of the uniformly sampled series `v` with spacing `dt`.
Interior points use central differences; both endpoints use second-order one-sided
stencils, so the result has the same length as `v` (which needs at least 3 samples).
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    tail = 3 / 2 * v[n] - 2 * v[n - 1] + 1 / 2 * v[n - 2]
    interior = (v[3:n] .- v[1:(n - 2)]) / 2
    return vcat(head, interior, tail) / dt
end
d_dt (generic function with 1 method)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Second time derivative of the uniformly sampled series `v` with spacing `dt`.
Interior points use the three-point central stencil; the endpoints use second-order
one-sided four-point stencils, so the output matches `v` in length (needs ≥ 4 samples).
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    tail = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    interior = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    return vcat(head, interior, tail) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Mass-weighted trace-free second moment of a planar `orbit` (rows = x, y) for the
requested `component = (i, j)`. `(1, 1)` gives x² − (x² + y²)/3, `(2, 2)` gives
y² − (x² + y²)/3, and every other component falls back to the off-diagonal x·y.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i, j = component[1], component[2]

    # Anything that is not the (1,1) or (2,2) diagonal is treated as the xy entry.
    if !(i == j && (i == 1 || i == 2))
        return mass .* (x .* y)
    end

    squared = i == 1 ? x .^ 2 : y .^ 2
    trace = x .^ 2 .+ y .^ 2
    return mass .* (squared .- trace ./ 3)
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

One quadrupole strain component: twice the second time derivative (finite
differences with spacing `dt`) of the trace-free moment tensor entry `component`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(h11, h12, h22)` quadrupole strain components of a single body's orbit.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

(2,2)-mode strain of a single (test-particle) orbit sampled with spacing `dt`.

Returns `(√(π/5)·h₊, -√(π/5)·hₓ)` with `h₊ = h11 - h22` and `hₓ = 2·h12`, where the
`hij` come from `h_22_quadrupole`. Constants are converted to the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2) mode carries a √(π/5) normalisation; the square root had been lost,
    # which silently rescaled the waveform by π/5 instead.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum of the `(h11, h12, h22)` quadrupole components of two bodies, each computed from
its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return (body1[1] + body2[1],
        body1[2] + body2[2],
        body1[3] + body2[3])
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

(2,2)-mode strain of a binary, from the orbits of body 1 (`mass1`) and body 2
(`mass2`) sampled with spacing `dt`.

Returns `(√(π/5)·h₊, -√(π/5)·hₓ)` with `h₊ = h11 - h22` and `hₓ = 2·h12` of the summed
quadrupole. Throws `ArgumentError` unless `mass1 + mass2` is 1 to within 1e-12.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    abs(mass1 + mass2 - 1.0) < 1e-12 ||
        throw(ArgumentError("Masses do not sum to unity"))

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation is √(π/5), not π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an orbital ODE solution into a (2,2)-mode waveform `(h₊-part, hₓ-part)`.

`soln` and `model_params` are passed to `soln2orbit`. For `mass_ratio == 0` the
test-particle (one-body) strain is used; for `0 < mass_ratio ≤ 1` the orbit is split
into two bodies with masses `m₂ = 1/(1+q)` and `m₁ = q·m₂` (summing to 1).

Throws `ArgumentError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # The comparison operators had been lost here (`mass_ratio1`), which made every
    # call throw an UndefVarError instead of validating the argument.
    mass_ratio ≤ 1 || throw(ArgumentError("mass_ratio must be <= 1"))
    mass_ratio ≥ 0 || throw(ArgumentError("mass_ratio must be non-negative"))

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        return h_22_strain_one_body(dt, orbit)
    end
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for state `u = [χ, ϕ]` with constant
parameters `p`, `M`, `e`. Returns `[dχ/dt, dϕ/dt]`; `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    common = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)

    dχ = common * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root)
    dϕ = common / (M * (p^(3 / 2)) * root)
    return [dχ, dϕ]
end

# Global configuration shared by the "true" simulation, the neural model and the loss.
mass_ratio = 0.0         # test particle: compute_waveform takes the one-body branch
u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
datasize = 250           # number of saved time samples
# NOTE(review): tspan uses Float32 literals while u0 and ode_model_params are Float64,
# so tsteps and dt_data end up Float32 — confirm this mixed precision is intended.
tspan = (0.0f0, 6.0f4)   # time span over which the GW waveform is generated
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved sample
dt_data = tsteps[2] - tsteps[1]  # sample spacing, used by the finite-difference strain
dt = 100.0               # fixed RK4 step size (the solves below use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Solve the relativistic model with fixed-step RK4, keeping only the samples at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# compute_waveform returns a pair; the first entry (the scaled h₊ part) is the target data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform; the line and the markers share one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the orbital variable $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally not recommended, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
# Small MLP mapping one input to two outputs. The first layer applies `cos`
# elementwise to the input. Every Dense layer starts from tiny weights
# (std = 1e-4) and zero biases, so the untrained output is ≈ 0 and the
# `1 .+ nn(...)` correction in ODE_model starts out ≈ 1.
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
# Use an explicitly seeded RNG so the initial parameters, and therefore the
# training run, are reproducible from one session to the next.
ps, st = Lux.setup(Random.Xoshiro(0), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[4.3230684f-6; -0.00011809734; -7.9386955f-5; 4.41291f-5; 8.772751f-5; 0.00015287504; -5.800296f-5; 6.7087814f-5; 4.047994f-5; 0.00024260959; -0.00025093087; 5.7153225f-5; 1.6160737f-5; 0.000114871924; 0.00012876611; -8.7987195f-5; 1.8102062f-5; -5.3517622f-5; 0.00010078511; -4.0580486f-5; -7.076341f-5; 9.738352f-5; -1.6547395f-6; -0.00023471726; 0.00010539557; -1.08210925f-5; 9.6256714f-5; -0.00015666222; 2.1174586f-5; -0.00015400401; 0.00025130785; 2.744245f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[2.213099f-5 0.00012458273 5.8424184f-6 -0.0001590462 -1.0906905f-6 -0.000118128155 -3.215123f-5 3.542986f-5 -3.4116205f-5 -4.291766f-5 0.0002087495 2.8733391f-5 -0.00014070742 2.9546813f-5 -9.952772f-5 -8.629157f-5 2.9022809f-5 3.4923552f-5 9.888075f-5 4.5149784f-5 6.0178834f-5 5.3136897f-5 6.0624574f-5 -1.7383914f-5 -7.220479f-5 -0.00015586847 -0.00028332623 0.00015608036 -0.0001545802 -4.705453f-5 -8.404783f-5 -0.00012049678; -0.0001518781 -0.00017180294 -2.8025825f-5 -0.00011792363 -9.968992f-5 0.00016379033 9.153116f-5 9.076037f-6 0.00021912552 9.2197035f-5 5.4703974f-5 2.0669439f-5 -7.9431404f-5 0.00022201009 -3.9613813f-5 6.100162f-5 -7.3025825f-5 6.4382552f-6 -9.5416435f-6 1.6496826f-5 6.424175f-5 -4.96209f-5 -0.0002031035 -1.1743933f-6 9.603886f-5 -9.620704f-5 0.00012453475 -3.5393146f-5 -0.000104819526 0.00011365106 -9.401721f-5 -9.275985f-5; -0.00017046787 0.00015367623 -9.518876f-6 0.00012217126 0.00012869324 0.00016370331 8.472627f-5 0.00013678378 -4.397512f-5 -7.717175f-5 -8.386407f-5 -4.9747054f-5 7.9842204f-5 -0.0001747744 -0.00018154003 4.7266014f-5 -5.2441792f-5 -1.225351f-5 5.7983556f-7 5.8863454f-5 6.686388f-5 -4.0740768f-5 2.6148495f-5 -2.9836483f-5 -8.778263f-5 2.343884f-5 -0.00010974065 0.00012588699 -0.00015901584 7.688414f-5 
-9.662283f-5 5.381661f-5; -0.00016491592 8.7900604f-5 -1.13237365f-5 0.000107449465 -6.4874956f-5 1.8503966f-5 8.055465f-5 -9.759723f-5 0.00010308935 -2.4517089f-5 -0.00015891883 4.114192f-5 0.000103513026 -5.487269f-5 -4.8389174f-5 5.487606f-5 6.718896f-5 4.821045f-5 -1.0819348f-5 -2.885768f-5 -5.4287786f-5 5.915162f-5 -0.0001351579 -0.00022491119 -0.00018327689 -0.00017579723 -0.00011861263 8.820366f-6 -3.2268123f-5 5.926266f-5 -7.349156f-5 3.1019987f-5; -0.00016955238 -1.8783912f-6 -0.000111367306 0.00027685447 1.2545552f-5 -4.1437743f-5 -6.354878f-5 6.6081135f-5 -0.00010453252 -7.74404f-5 2.388579f-5 -8.353213f-5 -0.00021215933 9.011667f-5 -8.461388f-5 -7.4932505f-5 -2.3083881f-5 -3.8703893f-5 5.8963014f-5 9.959298f-5 7.3768024f-5 0.0001788507 -5.0207494f-5 -3.9521958f-5 8.956094f-5 -2.7166638f-5 -0.00013491669 -3.865594f-5 -0.00010148934 -5.8452322f-5 0.00010113227 1.2098868f-5; 0.00013856415 -8.6434055f-5 1.8003513f-7 9.9494035f-5 -6.067965f-5 3.9071503f-5 9.0999965f-6 0.00012597862 -6.6218134f-5 -7.490775f-5 0.00013634538 -7.009828f-5 -2.445192f-5 2.494361f-5 -1.6913362f-5 -3.7218604f-5 4.3235457f-5 -4.3866894f-5 -3.3486118f-5 -0.00010419621 -5.699842f-6 3.1336764f-5 7.355777f-5 -0.00021078237 -2.6785156f-5 5.465725f-5 0.00015908177 8.017996f-5 -5.3888038f-5 0.00014164434 2.6634763f-5 -4.9134087f-5; 0.00010222056 0.00016987845 6.673776f-5 -7.3099363f-6 -0.00011364166 3.2399003f-5 3.684755f-5 -7.6938115f-5 -9.3826435f-5 8.3818304f-5 -8.487319f-5 -9.216597f-5 -8.805891f-6 5.3804797f-5 -3.2835407f-5 -6.408136f-5 3.2701173f-5 -2.1584314f-5 -6.168309f-6 8.128893f-5 0.00035468853 -7.6262084f-5 2.2319628f-5 -6.086259f-5 2.3093547f-5 -5.850942f-5 -4.1717194f-5 5.400848f-5 -0.00015467594 0.00010057838 -1.8977551f-5 2.705174f-5; -2.4375893f-5 -2.0553392f-5 -0.00010103192 -0.00017980921 -5.6599863f-5 -4.782806f-5 -7.581291f-5 -5.6904744f-5 -0.00015418361 1.060334f-5 9.868317f-5 -0.00010093473 -6.890198f-5 -0.00012152795 7.425431f-5 -8.507811f-5 4.9116123f-5 
-6.737668f-5 3.8759367f-6 -0.000108261185 3.0770603f-5 8.687193f-5 -1.9415282f-5 0.00010822613 0.00018509655 3.7990914f-5 3.1281925f-5 -8.355929f-6 -0.00021382648 8.47514f-5 -4.0289f-5 -0.00019113894; -2.5156145f-5 3.304198f-5 -9.943499f-5 0.00012463034 -0.00014683469 8.368502f-5 7.1046f-5 4.8219903f-5 -4.576254f-5 6.462355f-5 -0.00010711802 -0.00012869126 2.4238629f-5 0.00014963324 -1.846971f-5 7.189318f-5 5.183445f-5 -5.808645f-5 -6.195654f-5 0.00011486994 0.0001354383 0.00011793157 -0.00010000167 -2.6414858f-5 9.8981f-5 -0.00015920214 -3.9422757f-5 -8.8809524f-5 0.000116907206 0.00022518242 -5.967917f-5 0.00014945152; -5.254459f-5 1.1273664f-5 0.00015919122 -0.00016317262 -3.4698336f-5 -4.2988453f-5 9.986317f-6 4.982742f-5 0.00017173552 5.829386f-5 9.186499f-5 -3.0813295f-5 -9.58539f-5 0.00011463929 -0.00014222843 3.240442f-5 2.6225895f-5 -3.3430853f-5 -4.129353f-5 0.00015822295 7.80244f-6 -4.8887487f-5 0.00021257518 2.6003581f-5 -3.4102777f-5 -4.850925f-6 5.9148028f-5 2.1985374f-5 -4.4489803f-5 -6.572264f-5 0.00013712607 -5.022246f-5; -8.335758f-5 -0.00010985018 8.596642f-5 -0.00010272879 -2.8172115f-6 -0.00012632328 -1.7347997f-5 0.00019754322 4.9469807f-5 0.00016707453 7.5798f-5 -2.6289621f-5 -3.2796368f-5 -6.43487f-5 -5.576864f-6 0.00010697085 0.00011431937 4.688692f-5 -9.7977325f-5 -0.00011372343 -0.0001562482 -0.00020717722 0.00020987215 6.372042f-5 -0.00020596544 8.497464f-5 4.8352666f-7 7.99556f-5 -0.00011421237 -5.5402383f-5 4.016619f-5 -0.00013616563; -6.2740987f-6 1.8719948f-5 1.2537826f-5 -2.738648f-5 -0.00011668224 4.3667704f-5 -3.1882482f-5 -1.2190891f-5 0.00011877023 -0.00016889464 5.1157378f-5 0.000107209904 -2.6515622f-6 0.000101332545 9.981146f-5 3.728192f-5 0.00014604848 -0.00011368475 -6.704358f-5 -7.8658486f-5 -0.00015214198 4.6199624f-5 6.671393f-5 6.505186f-5 0.00010933275 -9.1327194f-5 -6.9407484f-5 9.0768604f-5 3.7978367f-5 4.7990427f-5 5.1415886f-5 5.9888946f-5; 3.9971834f-5 3.9344145f-6 0.000117107236 3.5846482f-5 -8.808366f-5 
-4.994297f-5 -1.271278f-5 6.5556575f-5 -6.2773634f-6 -7.373321f-6 7.8695535f-5 -0.00019661916 3.3480206f-5 1.9166764f-6 4.24358f-5 -0.00015950124 -8.5364285f-5 4.1009967f-5 1.4053503f-6 2.0252977f-5 4.487381f-5 8.310869f-5 7.059879f-5 9.307201f-5 3.456652f-5 -3.833794f-5 -8.734389f-5 0.0003562114 9.182422f-5 6.551148f-5 4.1977433f-5 1.9210245f-5; 4.7535894f-5 5.5592638f-5 9.441585f-5 0.0001019951 2.0536425f-5 -3.702007f-7 -0.00018773101 4.199791f-5 5.5873235f-5 -1.154509f-5 -7.024728f-5 -3.3435605f-5 -0.00010801733 5.697357f-6 -6.928141f-5 2.2030788f-5 -0.00010732498 0.00011970033 0.00017217269 -3.4115368f-5 -0.00016860031 0.00018611393 -6.5161665f-5 0.00015060439 0.00021143808 2.4479228f-5 -2.510358f-5 0.00010103037 3.8891154f-5 -8.574502f-5 0.00014259649 -0.000102533384; 2.2956376f-5 -7.524401f-5 0.00013650418 -5.4992477f-5 8.136092f-5 0.00012758718 -3.089742f-5 3.9692502f-5 2.219447f-5 -3.1489362f-5 -9.850625f-5 0.00010679936 0.00015056237 -6.024845f-5 9.8476376f-5 1.6397707f-5 -0.00010653585 0.0001530864 4.6384943f-5 1.116644f-6 -7.672825f-5 4.5727324f-5 1.5821668f-5 -0.00013629337 -3.91535f-5 -8.864622f-5 0.00015019646 0.00010775497 0.00024236739 0.00017345176 -0.00017790652 -0.00010959662; 0.00013813433 -3.5532415f-5 9.685314f-5 7.9335f-5 3.22317f-6 2.2038137f-5 -0.00011547873 -0.000101472455 5.940262f-5 -5.7265162f-5 8.6649765f-5 1.8959952f-5 3.97785f-5 4.6709172f-7 -2.540934f-6 -2.4111869f-5 0.00011610857 -0.00011628713 -9.679029f-5 -9.157744f-6 -5.973531f-5 5.8806723f-5 -2.9909314f-5 4.414481f-5 -5.2263567f-5 0.00012526962 2.2185453f-5 9.6747266f-5 -6.109495f-5 -0.00010336641 4.7942558f-5 0.00017383268; 3.326107f-6 -0.000111588684 -4.693585f-5 -7.235997f-5 5.365333f-5 -0.000122064666 -2.7124941f-5 -3.5430785f-5 -0.00015637728 0.00016759014 1.3415542f-5 0.00019841323 -2.2709937f-5 -0.0002108898 0.00012079433 8.684125f-6 -7.9207285f-7 7.002747f-6 0.00015861951 -0.00013185927 -1.1480658f-5 -0.00027324137 -8.974741f-5 1.3378607f-5 3.562759f-7 -6.092672f-6 
-0.0001278043 0.00020232364 2.403064f-6 3.2446907f-5 0.000119513665 0.00024291028; 0.00015129204 -7.57611f-5 6.9398135f-5 4.918363f-5 7.278028f-6 -6.646166f-6 -0.00010933514 -8.809555f-5 -2.698926f-6 -9.5304385f-5 9.7647164f-5 -0.00011025009 -3.2162936f-6 2.8880546f-5 6.857171f-5 0.00016356749 2.3635317f-5 4.1956602f-5 -9.959676f-5 0.00011391361 -0.0002789076 -0.000102154314 -0.000118051335 6.741538f-5 -9.662467f-5 -0.00014471919 -2.8688164f-5 -0.00016265191 -0.0001648567 -0.00014015207 1.0052355f-5 1.2478918f-5; 0.000116246294 0.00023815963 -3.033431f-5 -4.9324997f-5 0.00019932988 -1.7018008f-5 -3.5035653f-5 -0.00011192031 -4.9320846f-5 -0.00013935383 2.7017182f-5 -4.4097173f-6 -4.1282823f-5 4.8186106f-5 -7.853566f-5 6.62885f-5 3.477097f-5 -0.00010656598 2.2808617f-5 -8.618604f-6 -9.710672f-5 -0.00021642463 -5.828555f-5 -0.00012517537 -0.00013631707 3.373854f-6 5.8757094f-5 -1.5718724f-5 3.7178233f-5 0.00010816311 3.0000334f-5 0.000108872046; 2.5587724f-5 4.712529f-5 -4.243206f-5 -0.00016397028 6.194975f-5 0.0001283482 -3.1042364f-5 0.00016763888 -2.6620355f-5 -3.8867325f-5 4.8740083f-5 -9.552233f-5 -0.0002299383 0.000100549005 0.00017980499 -3.2500044f-5 -0.000101419064 3.3198023f-5 0.0001499569 5.911399f-5 -4.5179364f-5 -0.00015022862 -2.4577323f-5 9.23805f-5 -3.105124f-7 8.293009f-5 9.131569f-5 6.5562517f-6 -0.00010617359 0.00010741182 0.00014864464 -0.00013021771; 0.00010607055 -0.00017906743 6.5787106f-5 3.303144f-5 -5.5731985f-6 7.350083f-5 -4.8390884f-5 -8.154755f-6 -6.8817053f-6 5.032589f-5 -2.6136924f-5 -2.7121561f-5 -5.581437f-5 -0.0001340014 -3.9126924f-5 -0.00024977987 5.5918033f-5 5.714492f-6 -0.00011035948 -1.0089237f-5 6.3141604f-5 1.0132442f-5 0.00013333248 -0.00027956808 -4.2473373f-5 0.00012780934 -0.00010531071 -4.822986f-5 -4.268423f-5 4.8655093f-6 -0.00013032745 0.00011052098; -7.056389f-5 5.017567f-6 0.000215456 4.2697416f-6 8.492963f-5 -1.5705942f-5 -1.7550681f-5 0.000104780636 0.00022816125 -5.4760218f-5 -0.00017652006 3.124303f-6 
7.400283f-5 -3.5764795f-5 -3.748327f-5 1.3720647f-5 1.505961f-5 -8.13605f-5 -1.9918773f-5 -0.00011632778 -4.3765107f-5 5.7802663f-5 0.00013248678 0.00016607871 9.132428f-5 -7.1035094f-5 0.00011039956 -0.00016212145 0.00011289676 0.00013979823 2.8076242f-5 1.3316545f-5; -7.22545f-5 -9.22493f-6 0.00013955492 4.020493f-5 0.0001360574 -7.4911426f-5 4.949663f-5 6.499754f-5 -5.104655f-5 -2.2835622f-5 0.00019068271 -5.5695706f-5 -4.69663f-5 4.2160333f-5 0.00017374598 5.9210095f-5 1.0253809f-6 -0.0001267822 -0.000107181615 5.75398f-5 -2.6053265f-5 0.00011059913 -7.976586f-5 7.2271505f-5 -0.00013851386 6.131166f-5 -0.00014343236 7.765176f-5 -9.703335f-5 -1.8039007f-5 0.00011173982 0.00016611621; 7.03295f-5 -3.4172874f-5 -0.0001777707 0.000133616 -3.0991505f-5 -0.00014498051 -0.00018550864 -5.4302138f-5 -8.942252f-5 8.379484f-6 4.135159f-5 1.735834f-5 -5.6683144f-5 -1.2388505f-5 6.530579f-5 -0.00010824961 7.980534f-5 -6.79286f-5 0.0001711268 8.11928f-5 6.799788f-5 0.0001782986 -7.005856f-5 4.9356204f-6 -7.701407f-6 -5.2451738f-5 -8.547833f-5 4.02078f-6 3.4063163f-5 5.527918f-6 -2.9865085f-5 9.2751936f-5; -0.00010293183 -5.0927174f-5 -7.218729f-5 -0.00013717754 -8.8056564f-5 -5.566165f-5 0.0001420178 -6.8058f-5 -0.0001036334 -4.779835f-5 -5.206634f-6 0.00018317634 -1.509813f-5 -4.9318412f-5 -3.0791813f-5 0.00014799218 5.75135f-5 -7.496745f-5 3.7761292f-5 -2.495714f-5 8.03455f-5 -5.0305593f-5 0.00021378824 -7.234274f-5 2.015586f-5 0.0001740774 -9.523353f-5 0.00016451938 -5.537625f-5 5.8579694f-6 5.7439505f-5 -0.000120929966; -8.434897f-5 0.0001285136 9.349221f-5 -0.00011254757 0.00013492811 2.2353099f-5 4.7741276f-5 6.172222f-5 2.1141122f-6 2.414677f-5 3.957167f-5 5.194182f-5 0.00013160899 -6.5248954f-5 -2.0266352f-5 5.4837834f-5 8.85645f-5 5.9524114f-5 -5.692479f-6 0.00010409628 1.7489674f-5 2.1483438f-5 0.0001452799 0.00012262902 -5.5465145f-5 -3.2055254f-7 -0.000100540194 1.7635943f-5 -2.8534587f-5 -6.493419f-5 -4.9738108f-5 -5.400979f-5; 8.395551f-5 -0.00012259817 
-0.00012117672 1.7817021f-5 -0.00015437191 6.6292755f-6 -8.127832f-6 -5.784724f-6 7.819672f-5 7.681014f-5 0.00014200163 -9.3703255f-7 8.566589f-5 4.8651247f-5 -0.00026464558 0.00014326583 -0.0001063841 5.1172137f-5 5.26562f-5 8.4693675f-5 8.089853f-5 -4.8063653f-6 -5.0786744f-5 6.4259366f-5 -0.0001951248 5.390363f-5 2.2541142f-5 -0.0001390454 0.00010460742 -4.493794f-5 0.00013454963 -0.00016920351; 6.363425f-5 -3.3597144f-7 -3.9061237f-5 -6.829411f-5 0.0001606423 -8.1037826f-5 3.7290527f-5 -5.2216477f-5 -0.00013315909 6.127417f-5 4.921323f-5 -0.00011884563 4.7052257f-5 0.00013241825 -4.8264836f-5 -5.3466105f-5 0.00014495026 -0.00015753687 4.3423664f-5 -2.5807016f-5 1.8562632f-5 9.406546f-5 -0.00016784236 6.329863f-5 -6.88766f-5 3.9628114f-5 8.383034f-5 -0.00011808592 1.5173725f-6 2.4119401f-5 -2.8280074f-5 -7.3122287f-6; 9.64679f-5 9.926272f-5 -6.1527408f-6 6.323405f-5 6.924189f-5 0.00010019606 2.4677347f-5 -8.8689f-5 6.637421f-6 3.5229532f-5 -6.373331f-5 -0.00015805567 3.8725113f-5 -8.9776535f-5 -0.00017137219 -0.00010658753 0.00013598369 7.03612f-5 -4.7139434f-5 7.3236726f-5 7.5839453f-6 -1.3542715f-5 4.0874053f-5 -1.9153356f-5 3.818348f-5 0.00011311683 0.00016528956 5.2691834f-5 4.406619f-5 0.00011600675 2.7346492f-5 3.0911287f-5; -7.469636f-5 -7.643565f-5 9.485811f-5 -0.00012061365 -1.993382f-5 2.2050841f-5 -0.00015175773 -2.4226503f-5 0.00018644844 5.843065f-5 4.318866f-6 -0.00018078584 5.7793884f-5 -1.8209905f-5 -4.1155385f-5 -0.00016012872 -0.00012032888 3.137158f-6 -0.000121378005 2.0845853f-5 8.6732914f-5 3.805188f-5 -0.00014503243 -9.7330005f-5 -8.5914384f-5 -2.3565432f-5 -0.00023840251 -9.5089636f-5 9.6606364f-8 8.584818f-6 8.640262f-5 8.665559f-6; -0.000107139036 -0.00011218073 -8.7097105f-6 -0.00015154037 -2.8480059f-5 -0.000111963855 -7.8679244f-5 -3.2007403f-5 8.201816f-5 -1.2234173f-6 5.2671247f-5 -5.802831f-5 7.153545f-5 -1.6344951f-5 7.524136f-5 -0.00016741484 8.173474f-5 8.8912624f-5 -6.6302775f-5 8.442736f-5 9.397374f-5 4.654537f-5 -9.622006f-6 
-0.00015923945 5.8996196f-5 8.532837f-5 -2.8180333f-5 -4.6843823f-5 -1.3289656f-5 -0.00012736522 -0.00011598052 -1.6522208f-5; 2.9919076f-5 0.00010960125 -2.8032233f-5 4.9026832f-5 -5.1874704f-5 -4.1841635f-5 2.5186066f-5 -7.508696f-5 1.2514532f-5 8.2808394f-5 0.00013176417 -0.00021883748 -7.6660894f-5 2.8924138f-5 -0.00031680628 -3.1148704f-6 -6.0548762f-5 -2.6569276f-5 -8.0386024f-5 4.4382337f-5 -1.7699254f-5 0.00015025263 -3.1982938f-5 0.00024471775 -1.9547906f-5 -9.921873f-6 -0.00018200823 3.2237793f-5 6.3107735f-5 7.1539066f-6 0.0002744226 3.7988393f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[0.00012306418 1.1172839f-5 -3.743746f-5 -4.0265833f-5 -9.091714f-6 0.00011344894 -0.0001878203 -2.1620512f-5 0.000117312375 -5.2965617f-5 -1.8147355f-5 0.00020390513 -0.000101870115 -0.0001973662 -0.00012553546 -5.0706567f-5 6.1621868f-6 -0.0001490198 -0.0001741599 -4.9238733f-5 5.1763686f-6 -5.2879117f-5 -0.00014994782 5.0918785f-5 -2.5663052f-5 2.3836523f-5 0.00015615905 9.518999f-5 0.0001782519 4.0592247f-5 0.00026212962 -8.658997f-5; 0.00015341247 4.212888f-5 2.3445646f-5 -0.0001280238 -3.1255593f-5 -8.749178f-5 -6.7706256f-5 0.00013097301 -4.9825117f-5 1.5614603f-5 -0.00011160552 -2.3082392f-5 -8.5393665f-5 3.3719753f-5 -1.50378055f-5 -9.308439f-5 -4.8978007f-5 -0.00013926662 0.00014879026 -6.051587f-5 5.5755892f-5 -5.2332587f-5 2.9146922f-5 0.00028751796 0.00010468368 5.4045635f-5 0.000257395 -7.188611f-5 7.2825875f-5 0.0001407158 -0.00017371657 -1.9478895f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to Float32; however, in this case we need Float64.

julia
# Lux initialises the parameters as Float32; convert them to Float64 and flatten
# them into a ComponentArray so the optimizer can treat them as one vector.
const params = ComponentArray(ps |> f64)

# Bundle the network with its (empty) state so it can be called as
# `nn_model(x, ps)`; the parameters are supplied at each call, hence `nothing` here.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Neural-corrected Newtonian orbit ODE for state `u = [χ, ϕ]`.

Both `dχ/dt` and `dϕ/dt` equal the Newtonian rate `(1 + e cos χ)² / (M p^{3/2})`
times a learned factor `1 + NN(χ)`. The network `nn_model` is evaluated with
parameters `nn_params`, and `p, M, e` come from the global `ode_model_params`.
The layer state lives inside the `StatefulLuxLayer`, so it is not threaded through
here. `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    correction = 1 .+ nn_model([χ], nn_params)
    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained parameters, using the same fixed-step
# RK4 settings as the true model so the two waveforms are sampled identically.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference data and the untrained prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the waveform predicted by the neural ODE with
parameters `θ` and the reference `waveform`. The solve uses the same fixed-step
RK4 settings and sample times as the data.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(predicted, waveform)
end
loss (generic function with 1 method)

Warm up the loss function.

julia
# Evaluate the loss once at the untrained parameters before starting optimization.
loss(params)
0.0007182128284356028

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization callback: record the current loss `l` in `losses` and print the
iteration number taken from `state.iter`. Always returns `false`, so training is
never halted early.
"""
function callback(state, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It appears to work well here, most likely because the Newtonian model already provides a very good initial guess.

julia
# Gradients of the loss are obtained with Zygote reverse-mode AD.
adtype = Optimization.AutoZygote()
# `OptimizationFunction` expects an objective `f(u, p)`. Our loss only depends on the
# trainable parameters `u`, so the second (hyperparameter) slot is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a back-tracking line search and a small initial step, capped at 1000 iterations.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [4.323068424125576e-6; -0.00011809734132826039; -7.93869548941864e-5; 4.412910129755343e-5; 8.77275087986608e-5; 0.0001528750435680508; -5.8002959121896384e-5; 6.708781438639715e-5; 4.047993934362506e-5; 0.00024260958889503332; -0.0002509308687878812; 5.715322549802053e-5; 1.616073677722782e-5; 0.00011487192386986277; 0.00012876611435785245; -8.79871950018391e-5; 1.810206231312067e-5; -5.3517622291114366e-5; 0.000100785109680102; -4.058048580190113e-5; -7.076340989443208e-5; 9.738351945993427e-5; -1.6547395489373446e-6; -0.00023471725580703602; 0.00010539557115396794; -1.0821092473630182e-5; 9.625671373195129e-5; -0.00015666222316189552; 2.117458643620127e-5; -0.00015400401025525616; 0.0002513078507033911; 2.744244920903336e-5;;], bias = [2.682919702824844e-18, -1.4799383518759826e-16, -1.157429086870016e-16, 3.424433818171557e-17, 5.1447972768424805e-17, 1.661177667193009e-16, -5.855639887188296e-17, 8.032873146609223e-17, 1.7457725305138043e-17, -4.3086390714153174e-17, 2.0810831567653747e-16, 2.244786993717972e-17, 1.0720048335458235e-17, 8.75280562174714e-17, 7.388096535180597e-18, -8.117525086212224e-17, -2.378373510075908e-18, 4.6761221682442735e-18, 5.04222444413382e-17, -1.0649640688868747e-17, 4.7336768741554674e-18, 5.5197107740733994e-17, -2.6191256199023075e-18, -4.974048927568279e-16, 1.2253932117700457e-16, -4.381749427817379e-18, 2.1412467165377497e-16, -2.6478381977682907e-16, 4.9493178776454073e-17, -2.1605525884946086e-16, 2.302014426774637e-16, 1.535355262770181e-17]), layer_3 = (weight = [2.212840972001279e-5 0.00012458015025127526 5.839838914972718e-6 -0.0001590487847430747 -1.0932700057741734e-6 -0.00011813073452561463 -3.215381021496015e-5 3.5427281355941554e-5 -3.411878443207842e-5 -4.2920241147476664e-5 0.00020874691453300626 2.8730811928970526e-5 -0.00014070999551100093 2.954423306742465e-5 -9.953030185773907e-5 -8.629414897645564e-5 2.9020229513976855e-5 
3.492097293142081e-5 9.887817194917698e-5 4.514720402889636e-5 6.0176254729731314e-5 5.313431737827541e-5 6.062199444508724e-5 -1.738649389936106e-5 -7.220737116644551e-5 -0.00015587105391082735 -0.0002833288055704961 0.00015607778531499933 -0.00015458277285565966 -4.70571105403294e-5 -8.40504109952204e-5 -0.00012049935704331238; -0.00015187846542915004 -0.00017180329700579373 -2.8026183597944372e-5 -0.00011792398686168361 -9.969027700947177e-5 0.00016378997429626662 9.153080213934553e-5 9.075677782561725e-6 0.000219125159903228 9.219667595232917e-5 5.470361543473917e-5 2.0669079505637237e-5 -7.94317627406832e-5 0.00022200972785128763 -3.961417200855389e-5 6.10012623400283e-5 -7.302618428013884e-5 6.437896219608511e-6 -9.542002542843294e-6 1.6496466971033943e-5 6.424139178477617e-5 -4.962125853864968e-5 -0.00020310385548157292 -1.174752354125864e-6 9.603849798793431e-5 -9.620739979084249e-5 0.00012453439308217235 -3.539350541188699e-5 -0.00010481988533513284 0.00011365070281670102 -9.401756937960961e-5 -9.276021114430981e-5; -0.00017046798862753476 0.00015367611225923532 -9.51899224634067e-6 0.00012217114303024454 0.00012869312226070076 0.00016370319675883323 8.472615476461099e-5 0.00013678366698553762 -4.3975236312291e-5 -7.717186932455568e-5 -8.386418599871724e-5 -4.974716981411009e-5 7.984208818055648e-5 -0.00017477451152404764 -0.00018154014802308736 4.7265897901924215e-5 -5.2441908028014394e-5 -1.2253626354832976e-5 5.797194516517318e-7 5.886333778147551e-5 6.68637661332821e-5 -4.07408839778882e-5 2.6148378620235768e-5 -2.9836599114386567e-5 -8.778274755928667e-5 2.3438724693222493e-5 -0.0001097407654556819 0.00012588687272040223 -0.00015901595148494504 7.688402578463825e-5 -9.66229487490262e-5 5.381649347367658e-5; -0.00016491900011739188 8.789752767347757e-5 -1.1326812700878689e-5 0.00010744638869231334 -6.487803172627124e-5 1.850089004598124e-5 8.055157534703521e-5 -9.760030504390576e-5 0.00010308627109201362 -2.4520164761451843e-5 -0.00015892190846916049 
4.1138844362412506e-5 0.00010350995010388615 -5.4875765518090214e-5 -4.839225063755819e-5 5.4872985610695265e-5 6.71858813163785e-5 4.820737533895158e-5 -1.0822424219670928e-5 -2.8860755510560988e-5 -5.429086219943056e-5 5.9148542031604364e-5 -0.00013516097375965172 -0.0002249142619532029 -0.0001832799664682641 -0.00017580031114472494 -0.0001186157080794635 8.817289862929929e-6 -3.2271198752153756e-5 5.9259584058747366e-5 -7.349463546251622e-5 3.1016910479353955e-5; -0.00016955414424764005 -1.880153047641869e-6 -0.0001113690679066035 0.0002768527036661105 1.2543789909719059e-5 -4.143950458262663e-5 -6.355054314361708e-5 6.607937290219834e-5 -0.00010453428070640113 -7.74421652184652e-5 2.3884028325619486e-5 -8.353389524735338e-5 -0.00021216108974648638 9.011490952809147e-5 -8.461564096967691e-5 -7.49342663201647e-5 -2.3085643204366752e-5 -3.870565445857906e-5 5.896125263654259e-5 9.959121858820398e-5 7.376626205051589e-5 0.0001788489451993123 -5.0209256312782095e-5 -3.952371947694242e-5 8.955917916147655e-5 -2.716839948142422e-5 -0.0001349184468725831 -3.865770225991044e-5 -0.00010149110415831199 -5.84540840808355e-5 0.00010113050562918041 1.2097106138854563e-5; 0.00013856502429857238 -8.643318195086975e-5 1.8090771672394373e-7 9.94949072440647e-5 -6.067877673618391e-5 3.907237571123653e-5 9.100869119123024e-6 0.00012597949482309826 -6.621726113795029e-5 -7.490687910892467e-5 0.00013634625012793998 -7.009740564231531e-5 -2.4451047084454107e-5 2.4944483283743703e-5 -1.6912489246042957e-5 -3.721773153361761e-5 4.323632985628521e-5 -4.386602138825451e-5 -3.348524531307594e-5 -0.00010419533699010226 -5.698969362726727e-6 3.133763622767347e-5 7.355864310679346e-5 -0.0002107814949164755 -2.678428298227945e-5 5.46581208430728e-5 0.00015908264269395795 8.018083410146279e-5 -5.388716506845941e-5 0.00014164521364303892 2.6635635930983753e-5 -4.9133214434419554e-5; 0.00010222142133380939 0.00016987930699934514 6.673862059352465e-5 -7.309076532408183e-6 -0.0001136408014969577 
3.239986232466065e-5 3.684840827582964e-5 -7.693725514008137e-5 -9.382557494581069e-5 8.381916392344921e-5 -8.487232720263476e-5 -9.21651068308055e-5 -8.80503160777786e-6 5.3805656866014995e-5 -3.2834547378744874e-5 -6.408049979262548e-5 3.2702032844384133e-5 -2.158345399556107e-5 -6.167449224167292e-6 8.128978639825855e-5 0.00035468938709919855 -7.626122409027611e-5 2.2320487336685905e-5 -6.086172892146698e-5 2.3094407301241243e-5 -5.850856053774526e-5 -4.17163343042282e-5 5.4009340023463274e-5 -0.00015467508513728444 0.00010057923770028158 -1.897669120417622e-5 2.7052599183946662e-5; -2.4379503411854704e-5 -2.055700268167215e-5 -0.00010103552766007015 -0.0001798128227899146 -5.6603473677322667e-5 -4.783167012711491e-5 -7.58165201484504e-5 -5.690835449125473e-5 -0.00015418722021489595 1.0599730121009082e-5 9.867956261527279e-5 -0.00010093834269422105 -6.890559005181658e-5 -0.00012153155828920892 7.425070227262264e-5 -8.50817172983888e-5 4.911251246017991e-5 -6.738028740140657e-5 3.872326505437895e-6 -0.00010826479545424883 3.0766992967317064e-5 8.686831872116572e-5 -1.941889193470973e-5 0.00010822251944712855 0.00018509293888134868 3.798730405766147e-5 3.12783145266376e-5 -8.35953913402343e-6 -0.00021383009287159288 8.474779180623196e-5 -4.029261024757713e-5 -0.00019114255108501125; -2.515409211826191e-5 3.304403183375352e-5 -9.94329373226557e-5 0.00012463239657319812 -0.00014683263613749948 8.368707606660377e-5 7.104805375076465e-5 4.822195590333572e-5 -4.5760486291491154e-5 6.462560234447637e-5 -0.00010711597021347487 -0.00012868920636736214 2.4240681550347492e-5 0.00014963529359175742 -1.84676562368528e-5 7.189522987556921e-5 5.183650248893371e-5 -5.808439779701045e-5 -6.195448692795703e-5 0.0001148719904720758 0.00013544035139586303 0.00011793362612462762 -9.999961800547461e-5 -2.6412805500710943e-5 9.898305254847796e-5 -0.00015920009061135503 -3.942070431565689e-5 -8.880747079292215e-5 0.00011690925860404797 0.00022518447188286266 -5.967711583273141e-5 
0.00014945356927439379; -5.2542811377429456e-5 1.1275443225228051e-5 0.00015919299397362121 -0.00016317084015602274 -3.46965571821144e-5 -4.2986674148359815e-5 9.988096221125556e-6 4.982919921028256e-5 0.0001717372978732548 5.829563809533091e-5 9.186676828397621e-5 -3.081151595827292e-5 -9.585212325466724e-5 0.00011464106858917769 -0.00014222664691763864 3.2406198383314666e-5 2.6227673456082816e-5 -3.342907445178445e-5 -4.1291750736064604e-5 0.0001582247240862436 7.804219362216856e-6 -4.888570771575432e-5 0.00021257695692336275 2.6005360209128654e-5 -3.410099822352262e-5 -4.849146139826976e-6 5.91498064154342e-5 2.1987152590063162e-5 -4.448802434628766e-5 -6.572086209920583e-5 0.0001371278521532396 -5.022068222890562e-5; -8.335888831569047e-5 -0.00010985148672405207 8.596511635771123e-5 -0.00010273009227747178 -2.8185165607092484e-6 -0.00012632458945660964 -1.73493017260453e-5 0.0001975419190952557 4.9468501873360004e-5 0.00016707322085353067 7.579669638148202e-5 -2.6290926126949748e-5 -3.2797673140141776e-5 -6.435000521521623e-5 -5.578169276827725e-6 0.00010696954724919213 0.00011431806798865867 4.6885613317880345e-5 -9.797863002093243e-5 -0.00011372473362060358 -0.00015624951173011167 -0.00020717852292507653 0.0002098708400975896 6.371911605896721e-5 -0.00020596674128826406 8.497333785127689e-5 4.822215541223052e-7 7.995429497753208e-5 -0.00011421367797227722 -5.5403687821646256e-5 4.016488402926064e-5 -0.00013616693739239261; -6.27259368280206e-6 1.872145258806764e-5 1.2539330609289682e-5 -2.738497568366817e-5 -0.00011668073529159229 4.3669208761850734e-5 -3.1880977249381366e-5 -1.2189385589006203e-5 0.00011877173420493274 -0.00016889313082227671 5.1158883079016454e-5 0.00010721140898517038 -2.6500571862518663e-6 0.00010133405049549569 9.981296790050333e-5 3.728342454400331e-5 0.0001460499832405227 -0.0001136832444813596 -6.704207152375651e-5 -7.865698097753556e-5 -0.000152140471246237 4.620112923425025e-5 6.67154370689285e-5 6.505336824322114e-5 
0.00010933425693059596 -9.132568920045304e-5 -6.940597922090098e-5 9.077010946333015e-5 3.797987193087886e-5 4.799193157799189e-5 5.141739057709068e-5 5.989045058362941e-5; 3.9974327770729605e-5 3.936908521094558e-6 0.00011710973030534721 3.584897621880786e-5 -8.808116563483106e-5 -4.994047576799357e-5 -1.2710285594545636e-5 6.555906856778372e-5 -6.274869341318832e-6 -7.370826827906066e-6 7.869802920721699e-5 -0.0001966166690213473 3.348270019746091e-5 1.9191704236023274e-6 4.243829443697135e-5 -0.00015949874960239014 -8.536179102843927e-5 4.101246140311898e-5 1.4078443168157248e-6 2.025547127818282e-5 4.487630412444725e-5 8.31111812626094e-5 7.060128536654362e-5 9.307450654559701e-5 3.456901522116495e-5 -3.833544706463166e-5 -8.734139719632545e-5 0.00035621390834096427 9.182671619459678e-5 6.551397218247548e-5 4.197992729713422e-5 1.921273924057969e-5; 4.753785086163e-5 5.559459512536138e-5 9.441780983419368e-5 0.00010199705846556952 2.0538381961012555e-5 -3.682436608167857e-7 -0.0001877290562498809 4.1999868500934455e-5 5.587519243076022e-5 -1.1543133239863688e-5 -7.024532553095983e-5 -3.343364752996306e-5 -0.00010801537447892464 5.699314244188178e-6 -6.927944943360497e-5 2.203274542065457e-5 -0.00010732302345618943 0.00011970228641243152 0.00017217464463268372 -3.411341114602474e-5 -0.00016859835179107902 0.00018611588873849257 -6.515970764499476e-5 0.00015060634885394837 0.0002114400338376397 2.4481185193365917e-5 -2.5101622866862226e-5 0.00010103232469357527 3.889311099723566e-5 -8.574306431744217e-5 0.00014259844445568817 -0.00010253142701745397; 2.295885333510857e-5 -7.524153242926027e-5 0.00013650666023963835 -5.498999944179199e-5 8.13633998426202e-5 0.00012758965404130777 -3.089494184464209e-5 3.969498011642982e-5 2.2196947795559246e-5 -3.148688465234628e-5 -9.85037730039949e-5 0.00010680183568397552 0.00015056484353622677 -6.0245972961325585e-5 9.847885321718514e-5 1.6400185088376783e-5 -0.00010653337430780504 0.00015308887323258494 4.63874204818664e-5 
1.1191215911405673e-6 -7.672576957488557e-5 4.572980124280302e-5 1.5824145705078705e-5 -0.00013629089134855626 -3.9151021881165565e-5 -8.864374449116842e-5 0.00015019893562780225 0.00010775744540509003 0.0002423698644655408 0.00017345423493892183 -0.0001779040429007184 -0.00010959414412140629; 0.0001381357431196303 -3.553100416729183e-5 9.685454985766133e-5 7.933640860893425e-5 3.224581425009321e-6 2.2039548431621724e-5 -0.00011547732013103956 -0.00010147104351607608 5.940403186424309e-5 -5.726375114001476e-5 8.665117618100389e-5 1.896136361348042e-5 3.977991197781373e-5 4.685030514842711e-7 -2.5395226916736684e-6 -2.41104573692535e-5 0.00011610997878188931 -0.0001162857154017632 -9.678887747986025e-5 -9.156332660940017e-6 -5.97338987322386e-5 5.880813457361669e-5 -2.990790219745385e-5 4.414622141801781e-5 -5.2262156012596703e-5 0.00012527103515050115 2.218686474544253e-5 9.674867739840902e-5 -6.109353874170495e-5 -0.00010336499711092963 4.79439693403769e-5 0.00017383409410423313; 3.325710699149158e-6 -0.0001115890806119428 -4.693624547659562e-5 -7.236036950932081e-5 5.3652933578937004e-5 -0.00012206506259250379 -2.71253372062947e-5 -3.543118115491399e-5 -0.00015637767834565804 0.00016758974773833127 1.3415145380661127e-5 0.00019841282877162847 -2.2710333419733084e-5 -0.00021089020052901522 0.00012079393157758736 8.683728387345252e-6 -7.924690891435662e-7 7.002350989145224e-6 0.00015861911773239171 -0.00013185966569441573 -1.1481054266974522e-5 -0.0002732417666863195 -8.974780545546896e-5 1.3378210800821123e-5 3.558796691641963e-7 -6.0930680462356e-6 -0.00012780470175645416 0.0002023232485356284 2.4026677095513244e-6 3.24465111901526e-5 0.00011951326845004807 0.00024290987888374153; 0.00015128852759255514 -7.576461281392584e-5 6.939461967443317e-5 4.918011403526815e-5 7.274513097521941e-6 -6.6496811658418595e-6 -0.00010933865338563122 -8.809906354405921e-5 -2.702440963615484e-6 -9.53079004491899e-5 9.764364880284427e-5 -0.00011025360505561491 -3.2198086163363106e-6 
2.887703088812583e-5 6.856819458075307e-5 0.00016356397030643552 2.363180213899337e-5 4.195308723604115e-5 -9.960027160819913e-5 0.00011391009251122034 -0.0002789111101836799 -0.00010215782891724629 -0.00011805485042975802 6.741186301696283e-5 -9.66281884419564e-5 -0.00014472270482917822 -2.86916790745307e-5 -0.00016265542286677802 -0.00016486021264688574 -0.0001401555862347284 1.0048839828599394e-5 1.2475402615678366e-5; 0.00011624546779292291 0.00023815880150740074 -3.0335135748299635e-5 -4.932582349776581e-5 0.00019932905298863403 -1.7018834555905223e-5 -3.503647938980042e-5 -0.00011192113423514404 -4.9321672563947004e-5 -0.00013935465220620542 2.701635595676451e-5 -4.41054361128387e-6 -4.1283649339148414e-5 4.8185279878983406e-5 -7.853648684014462e-5 6.628767324817232e-5 3.4770142078359054e-5 -0.0001065668079585702 2.2807791096586046e-5 -8.619429977825575e-6 -9.710754646714054e-5 -0.0002164254515131851 -5.828637630240404e-5 -0.00012517619733026035 -0.0001363179003287018 3.373027779688798e-6 5.875626767609175e-5 -1.5719550433835538e-5 3.717740720692691e-5 0.00010816228633719796 2.9999507673526788e-5 0.00010887121926732968; 2.5588929769983545e-5 4.7126495541701276e-5 -4.243085475133043e-5 -0.00016396907647100347 6.195095722836145e-5 0.00012834940373471028 -3.1041158013577927e-5 0.00016764008416834118 -2.6619149298822183e-5 -3.886611937158239e-5 4.874128883276476e-5 -9.552112367197065e-5 -0.00022993709915277553 0.00010055021068586455 0.00017980619834309734 -3.249883791392037e-5 -0.00010141785803676594 3.319922855047401e-5 0.0001499581101820113 5.911519640398151e-5 -4.517815812101779e-5 -0.00015022741188255504 -2.4576116789430343e-5 9.238170407517531e-5 -3.0930657482204613e-7 8.293129742418413e-5 9.131689678218789e-5 6.557457549597376e-6 -0.00010617238347145744 0.00010741302487479534 0.00014864584757052698 -0.00013021650572733046; 0.0001060678013199715 -0.0001790701745594953 6.578436012564395e-5 3.3028693800538454e-5 -5.575944320343542e-6 7.34980830056575e-5 
-4.8393630200783725e-5 -8.157500920835808e-6 -6.884451173441625e-6 5.032314300842768e-5 -2.613967000433624e-5 -2.7124307158422747e-5 -5.581711513155509e-5 -0.00013400414456195349 -3.912966996362282e-5 -0.0002497826163708389 5.591528759754831e-5 5.711746220136779e-6 -0.00011036222269108008 -1.0091982509088834e-5 6.313885831691376e-5 1.012969631479107e-5 0.00013332973038263223 -0.0002795708234007206 -4.24761193390131e-5 0.0001278065946134514 -0.00010531345753047184 -4.823260598280539e-5 -4.2686976590681625e-5 4.862763384122806e-6 -0.0001303301934203728 0.0001105182353821364; -7.056119681784784e-5 5.020258985470796e-6 0.0002154586978840411 4.272433333411371e-6 8.493231837653917e-5 -1.5703250365765546e-5 -1.754798938329057e-5 0.00010478332808723382 0.00022816394596610176 -5.475752584158811e-5 -0.00017651736595423683 3.1269947876192553e-6 7.400552533816625e-5 -3.5762103262665916e-5 -3.748057894975357e-5 1.3723338842262847e-5 1.5062301489692916e-5 -8.135780842259672e-5 -1.991608075606954e-5 -0.00011632508573804092 -4.376241519914343e-5 5.780535468964113e-5 0.0001324894689141597 0.0001660814010655783 9.132696834874404e-5 -7.103240238485767e-5 0.00011040225183840137 -0.00016211875515276738 0.00011289945507904577 0.00013980091864956441 2.8078933272740828e-5 1.3319236703850694e-5; -7.225267288822885e-5 -9.223102077623895e-6 0.0001395567469639832 4.02067580775978e-5 0.00013605922324267487 -7.490959795042601e-5 4.949845790315018e-5 6.49993690071767e-5 -5.104472254834436e-5 -2.2833793972078503e-5 0.00019068453732093576 -5.569387756486819e-5 -4.6964471019694886e-5 4.2162160772466473e-5 0.00017374781056922605 5.921192316372217e-5 1.0272090289585864e-6 -0.00012678036544321478 -0.00010717978691764983 5.7541628867962204e-5 -2.6051437069787973e-5 0.00011060096041682343 -7.976402955910393e-5 7.227333300894371e-5 -0.00013851203218883765 6.131348529515577e-5 -0.0001434305358802789 7.765359079847313e-5 -9.703152134398227e-5 -1.8037178941557478e-5 0.00011174164867175508 
0.0001661180376881895; 7.032891566976044e-5 -3.417345892759633e-5 -0.00017777127818791936 0.00013361542055261434 -3.0992090116541525e-5 -0.0001449810983578716 -0.0001855092300067735 -5.430272308532923e-5 -8.942310837685029e-5 8.378898699258276e-6 4.1351006573488604e-5 1.735775417526386e-5 -5.668372928286825e-5 -1.2389089945116111e-5 6.530520163122021e-5 -0.00010825019440036977 7.980475179263468e-5 -6.792918743687556e-5 0.00017112621257891212 8.119221142604119e-5 6.799729139655065e-5 0.00017829800758389078 -7.005914399289455e-5 4.935035165354643e-6 -7.70199262862509e-6 -5.245232337276377e-5 -8.547891361760659e-5 4.020194908466225e-6 3.406257796794256e-5 5.527332675941866e-6 -2.986567001444875e-5 9.275135066471148e-5; -0.0001029319695254961 -5.0927313498657103e-5 -7.218743245463659e-5 -0.00013717768284838506 -8.805670346479698e-5 -5.5661790032142246e-5 0.00014201766192566752 -6.80581373336896e-5 -0.00010363353918652844 -4.7798488015509885e-5 -5.206773369686464e-6 0.00018317619709417373 -1.5098269034904492e-5 -4.931855197643669e-5 -3.079195208844521e-5 0.00014799203803416507 5.751336233227647e-5 -7.496759041093146e-5 3.776115285002805e-5 -2.4957279823109045e-5 8.034535734335113e-5 -5.030573207765642e-5 0.0002137881007336838 -7.234287601310376e-5 2.0155720118812246e-5 0.0001740772628516717 -9.523367167563937e-5 0.00016451924478754642 -5.537639059472696e-5 5.857829943882602e-6 5.743936584334307e-5 -0.00012093010569803281; -8.43461393661475e-5 0.0001285164244190053 9.349504308274798e-5 -0.00011254473883227197 0.00013493093775550558 2.2355928202962136e-5 4.7744105367112234e-5 6.172504709766263e-5 2.1169416949343636e-6 2.414959909296754e-5 3.9574500086784995e-5 5.1944650302996276e-5 0.00013161182051490467 -6.524612489657326e-5 -2.0263522848027493e-5 5.484066333237754e-5 8.856732802897159e-5 5.952694392263123e-5 -5.689649689069857e-6 0.00010409911102115452 1.7492503510676015e-5 2.1486267550150177e-5 0.00014528273369151596 0.00012263184817934444 -5.5462315677782836e-5 
-3.1772309839134455e-7 -0.00010053736422478643 1.763877208657605e-5 -2.8531757371724447e-5 -6.493135969423572e-5 -4.9735278472469656e-5 -5.4006960445890344e-5; 8.395528800439695e-5 -0.00012259838752095816 -0.00012117694188953659 1.781680236864989e-5 -0.0001543721306242169 6.629056655691962e-6 -8.12805103658535e-6 -5.78494302852257e-6 7.819649852000794e-5 7.680991927747706e-5 0.00014200141538287634 -9.372514008485916e-7 8.56656691886359e-5 4.865102835003054e-5 -0.0002646458020771036 0.0001432656130183368 -0.00010638431615542014 5.117191847067951e-5 5.2655980993329805e-5 8.469345573222882e-5 8.089830896471336e-5 -4.806584116415149e-6 -5.078696324370196e-5 6.425914705152688e-5 -0.0001951250166038169 5.390341118447157e-5 2.2540923395461593e-5 -0.00013904561694824287 0.00010460719874947008 -4.4938159249687285e-5 0.00013454940869826016 -0.00016920373366327343; 6.3633841734139e-5 -3.3638170248599784e-7 -3.9061647015164953e-5 -6.829452021816559e-5 0.00016064189267739915 -8.103823628812559e-5 3.729011669451576e-5 -5.221688760980743e-5 -0.00013315950044967052 6.127376119070692e-5 4.9212820983259615e-5 -0.00011884604018019307 4.705184701992365e-5 0.0001324178419115428 -4.8265245907945475e-5 -5.346651514014897e-5 0.000144949849442804 -0.0001575372762938569 4.342325419815553e-5 -2.580742611212938e-5 1.8562221982160063e-5 9.406504696630227e-5 -0.00016784276803560397 6.329821636131922e-5 -6.887701155688307e-5 3.9627703614990936e-5 8.382992828659995e-5 -0.00011808632834186293 1.5169622718018222e-6 2.411899086541505e-5 -2.8280484087415336e-5 -7.312638984564142e-6; 9.647063117668811e-5 9.926545018756312e-5 -6.150007977655717e-6 6.323678527920183e-5 6.924462351676177e-5 0.00010019879547838706 2.468007980283366e-5 -8.868626471880441e-5 6.640153712994004e-6 3.523226470642981e-5 -6.373057411328995e-5 -0.0001580529381338644 3.872784574710587e-5 -8.977380210340406e-5 -0.00017136945396701553 -0.00010658480030756347 0.00013598642246293654 7.036393500825918e-5 -4.7136701398222926e-5 
7.323945894100405e-5 7.586678036799304e-6 -1.35399825761403e-5 4.087678527653995e-5 -1.9150623559103056e-5 3.81862125483601e-5 0.0001131195647009539 0.00016529229171069305 5.269456661847902e-5 4.4068922504864563e-5 0.0001160094804778378 2.734922483577737e-5 3.091402024225112e-5; -7.470070176612262e-5 -7.643999032994312e-5 9.485376886471453e-5 -0.00012061798970069141 -1.9938157670246367e-5 2.2046502458722106e-5 -0.0001517620714320163 -2.4230841695382026e-5 0.00018644409811270044 5.842631235756201e-5 4.314527636978139e-6 -0.0001807901737219017 5.7789546012937885e-5 -1.8214243204837307e-5 -4.115972314580484e-5 -0.00016013306066712816 -0.00012033321599563216 3.1328195792642924e-6 -0.00012138234359996505 2.0841514747301992e-5 8.672857602971702e-5 3.804754106125587e-5 -0.00014503677284199426 -9.733434325862056e-5 -8.591872228395172e-5 -2.356977001443633e-5 -0.00023840685225410627 -9.509397404584984e-5 9.226789682922486e-8 8.58047933903223e-6 8.639828393390295e-5 8.661220640644517e-6; -0.0001071416435336038 -0.00011218333646340028 -8.712317943921273e-6 -0.00015154297842330094 -2.8482666388501146e-5 -0.0001119664619947865 -7.868185174104438e-5 -3.2010010622093056e-5 8.20155520302913e-5 -1.2260246732466024e-6 5.26686391056763e-5 -5.8030915974291667e-5 7.153283978893481e-5 -1.634755870710854e-5 7.523875421228375e-5 -0.0001674174444671931 8.173213165334492e-5 8.891001663484885e-5 -6.630538235569424e-5 8.442475256358449e-5 9.397112909553096e-5 4.654276300769195e-5 -9.624612979886196e-6 -0.00015924206207653543 5.899358904187599e-5 8.532576349877801e-5 -2.818294059054117e-5 -4.684643034507304e-5 -1.3292263142128178e-5 -0.00012736782344673612 -0.00011598312529183294 -1.652481558651562e-5; 2.991939400059692e-5 0.00010960156723710154 -2.8031914917240012e-5 4.902715008908453e-5 -5.18743858268254e-5 -4.18413168520918e-5 2.518638357257026e-5 -7.50866405727108e-5 1.251484954961382e-5 8.280871188586492e-5 0.00013176449065182232 -0.00021883716035548114 -7.666057572381052e-5 
2.8924455909569875e-5 -0.00031680596512190647 -3.1145524647700433e-6 -6.0548444162425064e-5 -2.6568958128639234e-5 -8.038570595120595e-5 4.4382655323413586e-5 -1.769893642661623e-5 0.00015025294633202676 -3.198261983955689e-5 0.000244718071916997 -1.954758809036634e-5 -9.921555283956147e-6 -0.00018200790727890604 3.2238110724997597e-5 6.310805256331857e-5 7.154224567480129e-6 0.00027442292557649183 3.798871111909441e-5], bias = [-2.5795167443868086e-9, -3.590213288192241e-10, -1.1610569473729257e-10, -3.076156884959727e-9, -1.7618166470596982e-9, 8.72586810442138e-10, 8.59803834023527e-10, -3.6102217008505316e-9, 2.0529385641692177e-9, 1.7789120422175976e-9, -1.305108146516895e-9, 1.5050404225464013e-9, 2.4940131391033954e-9, 1.957034186083882e-9, 2.477647609057324e-9, 1.4113279686275946e-9, -3.9623561538315764e-10, -3.5149813588346717e-9, -8.262626767979646e-10, 1.2058303039854427e-9, -2.7458700576940057e-9, 2.6917668814737094e-9, 1.8281039016745433e-9, -5.852163855293614e-10, -1.394829885872775e-10, 2.8294461170084398e-9, -2.1884796535763708e-10, -4.1026692676019053e-10, 2.7327747907050254e-9, -4.338466968897107e-9, -2.607416016590884e-9, 3.1795904561935447e-10]), layer_4 = (weight = [-0.0005662760310818158 -0.0006781675170934697 -0.0007267778196601953 -0.0007296059755809801 -0.0006984320014338072 -0.000575891399094857 -0.0008771606445582344 -0.00071096056514598 -0.0005720278920049972 -0.000742305901954489 -0.0007074876751000651 -0.00048543517816454886 -0.0007912103275237823 -0.0008867064643086312 -0.0008148756701667135 -0.0007400468797441288 -0.000683178168356445 -0.0008383598698343944 -0.0008635002388893792 -0.0007385790585355142 -0.0006841638166919786 -0.000742219307706576 -0.0008392881016257585 -0.0006384215658918247 -0.000715003410160227 -0.0006655036530496896 -0.0005331813105804819 -0.0005941503621289309 -0.0005110883010817317 -0.0006487476797265823 -0.0004272106040214516 -0.000775930327338248; 0.0003777717861803124 0.0002664882371319962 
0.0002478050049219107 9.63354826134831e-5 0.00019310374330997512 0.00013686757168692703 0.00015665309733218518 0.00035533227461930145 0.00017453421209235096 0.00023997393881754623 0.00011275382588857367 0.0002012769521386296 0.00013896564664799983 0.00025807908128802355 0.0002093215063364491 0.00013127495556532702 0.00017538135141317458 8.509264795110928e-5 0.00037314961665168965 0.00016384347732990913 0.00028011519558398455 0.00017202671787090975 0.00025350625544537917 0.0005118773202486011 0.0003290430361469982 0.0002784049350238728 0.0004817543631940057 0.00015247324620534813 0.0002971851815675749 0.0003650750173913695 5.064274115750035e-5 0.0002048804632030237], bias = [-0.0006893403586862367, 0.00022435935938525207]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # `losses` holds one value per callback invocation; plot it against its index as a
    # thick line overlaid with circular markers.
    lines!(ax, eachindex(losses), losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results by comparing the data with the untrained and trained neural ODE waveforms

julia
# Rebuild the problem with the optimized parameters `res.u`.
# NOTE: this rebinds the globals `prob_nn` and `soln_nn` from the untrained cell. `loss`
# still behaves the same, because it always overrides the parameters with `p=θ` in `solve`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, dt=dt, adaptive=false, saveat=tsteps))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference waveform.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Untrained neural ODE prediction (computed in the earlier simulation cell).
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Trained neural ODE prediction.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # One legend entry per series, each pairing the line with its markers.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
# `versioninfo` is exported by InteractiveUtils; prints Julia build/platform details.
versioninfo()

# Append GPU toolchain details, but only for backends whose packages are loaded in this
# session AND whose devices MLDataDevices reports as functional. The `@isdefined` checks
# come first so `CUDA` / `AMDGPU` are never referenced when the package is absent.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()  # blank separator line before the backend report
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()  # blank separator line before the backend report
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.