Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.

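We need a very crude two-body path. Assume the one-body motion is a Newtonian two-body relative position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$, and use Newtonian formulas to recover $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$ with $M = m_1 + m_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, §4.3).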

julia
"""
    one2two(path, m₁, m₂)

Split the relative position `path` into the positions of two bodies with masses `m₁`
and `m₂`, placing their center of mass at the origin.

Returns `(r₁, r₂)` with `r₁ = m₂ / (m₁ + m₂) .* path` and `r₂ = -m₁ / (m₁ + m₂) .* path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    return (m₂ / total_mass) .* path, (-m₁ / total_mass) .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution in orbital-element form into Cartesian coordinates.
# Each column of `soln` is one time sample, and `soln` must have 2 or 4 rows:
#   - 2 rows: (χ, ϕ). The constants p and e come from `model_params = (p, M, e)`;
#     M is not used here.
#   - 4 rows: (χ, ϕ, p, e), so p and e may vary per sample.
# Returns a 2 × size(soln, 2) matrix whose rows are x and y, where
#   r = p / (1 + e cos χ),  x = r cos ϕ,  y = r sin ϕ.
# Because of `@views`, the row slices are views into `soln` rather than copies.
# Throws ArgumentError if `soln` does not have 2 or 4 rows, or if it has 2 rows and
# `model_params` is not a length-3 collection.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    nrows in (2, 4) || throw(ArgumentError("size(soln,1) must be either 2 or 4"))

    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        # `nothing` (the default) has no length, so check it explicitly to report a
        # clear error instead of a MethodError.
        (model_params !== nothing && length(model_params) == 3) ||
            throw(ArgumentError("model_params must have length 3 when size(soln,1) = 2"))
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

`d_dt` computes the first derivative of uniformly sampled data. It uses central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v, dt)

Approximate the first derivative of the uniformly sampled series `v` (spacing `dt`).
The interior uses central differences, and the two endpoints use second-order
one-sided stencils. The result has the same length as `v`; `v` needs at least three
samples.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(left_edge, interior, right_edge) / dt
end
d_dt (generic function with 1 method)

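`d2_dt2` computes the second derivative of uniformly sampled data. It uses the standard three-point stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0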

julia
"""
    d2_dt2(v, dt)

Approximate the second derivative of the uniformly sampled series `v` (spacing `dt`).
The interior uses the three-point stencil `v[i-1] - 2v[i] + v[i+1]`, and the two
endpoints use second-order one-sided stencils. The result has the same length as `v`;
`v` needs at least four samples.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit, together with helpers that turn it into the (2,2)-mode gravitational-wave strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free mass moment tensor along `orbit`. Row 1 of
`orbit` holds x and row 2 holds y.

`component` selects the entry:
  - `(1, 1)`: `x² - (x² + y²)/3`
  - `(2, 2)`: `y² - (x² + y²)/3`
  - anything else: `x y`

The result is scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    Ixx = xs .^ 2
    Iyy = ys .^ 2

    tensor = if component[1] == 1 && component[2] == 1
        Ixx .- (Ixx .+ Iyy) ./ 3
    elseif component[1] == 2 && component[2] == 2
        Iyy .- (Ixx .+ Iyy) ./ 3
    else
        xs .* ys
    end

    return mass .* tensor
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of one trace-free moment-tensor component
along `orbit`. The samples are spaced by `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Compute the quadrupole terms for the (1,1), (1,2), and (2,2) tensor components of a
body with mass `mass` moving along `orbit`.

Returns the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    h11 = quad((1, 1))
    h22 = quad((2, 2))
    h12 = quad((1, 2))
    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt, orbit)

Compute the (2,2)-mode strain polarizations for a single unit-mass body moving along
`orbit`. `orbit` is a 2-row matrix of x/y positions sampled every `dt`.

Returns `(h₊, hₓ)`, scaled by the (2,2)-mode amplitude prefactor `√(π/5)`:
`h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12`. The `h_ij` terms come from
`h_22_quadrupole`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Amplitude prefactor of the quadrupole (2,2) mode is √(π/5), not π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole terms of two bodies, each with its own orbit and mass.

Returns `(h11, h12, h22)`. Each entry is the componentwise sum of the two bodies'
`h_22_quadrupole` results.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q₁ = h_22_quadrupole(dt, orbit1, mass1)
    q₂ = h_22_quadrupole(dt, orbit2, mass2)
    # Both tuples are ordered (h11, h12, h22); add them pairwise.
    return q₁[1] + q₂[1], q₁[2] + q₂[2], q₁[3] + q₂[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Compute the (2,2)-mode strain from the orbits of two black holes with masses `mass1`
and `mass2`. The orbits are sampled every `dt`.

Returns `(h₊, hₓ)`, scaled by the (2,2)-mode amplitude prefactor `√(π/5)`.

Throws `ArgumentError` if `mass1 + mass2` differs from 1 by `1e-12` or more.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # Input validation: use an exception rather than `@assert`, which may be disabled.
    abs(mass1 + mass2 - 1.0) < 1e-12 ||
        throw(ArgumentError("Masses do not sum to unity"))

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Amplitude prefactor of the quadrupole (2,2) mode is √(π/5), not π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Convert an ODE solution `soln` into the (2,2)-mode strain `(h₊, hₓ)`. The samples are
spaced by `dt`.

`soln` and `model_params` are passed to `soln2orbit`. When `0 < mass_ratio ≤ 1`, the
orbit is split into two bodies with masses `m₁ + m₂ = 1` and `m₁ / m₂ = mass_ratio`.
When `mass_ratio == 0` (test-particle limit), the orbit is treated as a single
unit-mass body.

Throws `ArgumentError` if `mass_ratio` is not in `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    mass_ratio ≤ 1 || throw(ArgumentError("mass_ratio must be <= 1"))
    mass_ratio ≥ 0 || throw(ArgumentError("mass_ratio must be non-negative"))

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # m₂ = 1 / (1 + q) and m₁ = q m₂, so the two masses sum to 1.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

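`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using $u[1] = \chi$ and $u[2] = \phi$:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^2\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2}\sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.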

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE. The state is `u = (χ, ϕ)`, and the
constants are `(p, M, e)`. `t` is unused.

Returns `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    c = cos(χ)

    numer = (p - 2 - 2 * e * c) * (1 + e * c)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = numer * sqrt(p - 6 - 2 * e * c) / (M * (p^2) * denom)
    dϕ = numer / (M * (p^(3 / 2)) * denom)
    return [dχ, dϕ]
end

# A mass ratio of 0 makes `compute_waveform` use the single-body strain.
mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions (χ₀, ϕ₀)
# Number of saved samples along the trajectory.
datasize = 250
# NOTE(review): tspan uses Float32 literals while u0 is Float64 — confirm this mix is intended.
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved sample
# Spacing of the saved samples; used for the finite-difference derivatives of the waveform.
dt_data = tsteps[2] - tsteps[1]
# Fixed step size for the non-adaptive RK4 solves below.
dt = 100.0
# Orbit constants shared by the relativistic and neural-network models.
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4 (step `dt`), saving at `tsteps`.
# The "+" polarization (`first` of `(h₊, hₓ)`) becomes the training data `waveform`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the data waveform as a line plus markers that share one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the current value of $\chi$) and has two outputs. We'll write a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

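We will deviate from the standard neural network initialization and use WeightInitializers.jl. Weights are drawn from a truncated normal distribution with a standard deviation of 1e-4, and biases start at zero.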

julia
# Network mapping a length-1 input (χ in `ODE_model`) to two outputs. The first layer
# applies `cos` elementwise to the input. Every weight matrix is drawn from a truncated
# normal with std 1e-4 and every bias starts at zero, so the untrained network's output
# is close to zero.
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
# Initialize the parameters and the (empty) layer states. No seed is set in the visible
# code, so the initial weights will differ between sessions unless one is set first.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.00015260985; -7.8647f-5; 0.000102452184; -0.00019427594; -0.00011503223; 2.5274567f-5; -0.00010334692; -4.256305f-5; 0.00015810276; -6.772859f-5; 8.890157f-5; 1.5272062f-5; -4.1950054f-5; 8.6136315f-5; -6.244482f-5; -1.4175181f-5; -5.20804f-5; -9.597999f-5; 3.4229946f-5; -0.00012799038; 0.00019974481; 0.00014724274; -4.6165682f-5; 9.398019f-6; 8.354277f-5; 6.936131f-5; 7.678188f-5; -0.00014963455; -0.00015557905; 0.00018619793; 0.00010090831; -9.528509f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[0.00010336576 0.00014304143 -6.098224f-5 0.0001265967 -1.4683301f-5 -3.0453868f-5 -2.0025649f-5 4.7287645f-6 -1.6355687f-5 5.1433203f-6 -0.0001068297 -8.302892f-5 6.010248f-5 5.9150185f-5 -5.2049286f-6 -4.0897554f-5 0.00011712862 0.00012961443 -2.4897265f-5 3.6077814f-5 -0.00013892032 9.5397314f-5 -0.00021568655 -9.0229136f-5 1.0694922f-5 -8.6894885f-5 7.089319f-5 0.00015622133 -8.582395f-5 -0.00013588274 4.973019f-5 -8.3185085f-5; -2.746344f-5 6.5333086f-5 -2.3266264f-5 0.00010730686 -2.5132324f-5 0.00013817326 -0.00025277166 8.8609406f-5 -0.00012462122 -5.865671f-5 -2.0224203f-5 -4.7625825f-5 -2.3824121f-5 -4.3506552f-5 -5.0604417f-6 0.00015306397 9.749967f-5 9.1946415f-5 0.00013610683 5.0231087f-5 2.3453045f-5 0.00010960522 2.0615578f-5 -0.00014816127 -6.280242f-5 8.3167346f-5 -0.000120661956 2.6820064f-5 2.4995294f-5 4.1317755f-5 -0.00023093085 0.00017696185; 4.451612f-5 -0.00015296608 0.00018965591 -0.00034871692 -2.8812688f-5 2.0696107f-5 -0.00012388628 -1.3153993f-6 2.9777379f-5 -1.5032774f-5 2.0424222f-5 4.5064862f-5 -4.0102972f-5 -7.675573f-5 4.307936f-5 2.4864257f-5 -0.000110762485 6.097245f-5 0.00020824288 2.7961007f-5 2.3124865f-5 -5.362765f-5 -9.248707f-5 2.183818f-5 -0.00011357458 2.302923f-5 0.00019080541 0.00014296857 5.80424f-6 
-0.00026206844 -2.4804749f-5 0.0001412071; 0.00010141598 5.3543346f-5 -0.00018893697 -1.8788396f-6 0.00019048256 2.5341644f-5 -0.00011857316 0.00011877262 -2.0674357f-5 7.7540644f-5 0.000103709644 -5.217867f-5 -5.513836f-5 -2.463093f-5 1.3988208f-6 0.000106251544 3.03999f-5 -0.0001273353 -2.3947472f-5 -2.7751852f-5 -8.860473f-6 -0.00016088519 -6.679357f-5 -5.551096f-6 8.63898f-5 -6.7108704f-5 -0.00017555173 3.10481f-5 9.4600924f-5 0.0001306459 1.321488f-5 -6.543699f-5; 3.3288197f-5 3.8171504f-5 0.00017684491 3.0029623f-5 9.1213f-6 -1.921464f-5 7.241037f-5 -4.6479418f-5 8.1599654f-5 4.4619555f-5 0.000100976096 5.7833848f-5 -0.00013668794 6.214834f-6 -0.0001077683 -6.9692076f-5 -2.0406558f-5 -2.353007f-5 7.045837f-5 -4.5222754f-5 6.766817f-5 -2.531281f-5 1.9711935f-5 -2.8020795f-5 -1.6339955f-5 -5.959211f-5 -0.00018661523 -3.828506f-5 0.00012743239 3.0526524f-5 -5.8722722f-5 6.0574606f-5; -6.697135f-5 6.4912256f-6 5.4817003f-5 -0.0001403284 3.4606175f-5 8.523172f-5 3.817927f-5 -0.00015685994 -0.00026720564 0.00015604017 8.61448f-5 0.0001266672 2.6552745f-5 -2.3528293f-5 0.00027174273 -2.4294839f-5 -1.4893094f-5 -8.366422f-5 1.0592838f-5 9.929952f-5 -1.1529163f-5 4.970707f-5 -7.99759f-5 -6.039097f-5 -2.4051865f-6 -3.05213f-5 0.0001803932 9.994975f-5 -0.00014679483 2.5736403f-5 6.7794834f-5 0.00015436107; -0.000116543546 3.7097256f-5 7.708593f-5 5.6223587f-5 -0.00019021955 -4.2017446f-6 0.00013327449 1.558171f-5 8.2043107f-7 9.624121f-6 -8.378659f-6 0.00019077846 -8.901667f-5 -1.57753f-5 -0.0001232794 0.00014974693 -0.00010503569 -9.112024f-5 -9.518666f-5 5.8415248f-5 0.00012128274 -9.3657654f-5 -5.143463f-5 -6.568752f-5 3.619359f-5 5.7061725f-6 8.927777f-5 -0.00018928655 -3.2527107f-5 7.1206145f-5 -1.8165065f-5 5.5876302f-5; -8.9551077f-7 4.1096355f-5 -0.00011536459 0.00011555127 0.0002488086 -3.282806f-5 9.326017f-5 2.9691535f-5 -3.6147553f-6 9.084059f-5 4.885324f-5 2.8100174f-5 -1.4230024f-5 -0.00010649405 2.7216916f-5 0.000109900735 4.2984688f-5 7.108747f-5 
-7.1960756f-5 -3.124629f-5 4.512778f-5 8.9931375f-5 -4.4618482f-5 -4.9339167f-5 -0.00012857803 3.0408488f-5 0.00015549271 -3.345534f-5 -3.8579048f-5 0.00017333613 -0.0001252742 4.8448208f-5; -0.00012982493 0.00012063537 0.00023806038 3.215419f-5 9.1727255f-5 0.00011761996 2.8813574f-5 -4.941275f-6 -0.00013736988 8.048708f-5 -1.434258f-5 6.188689f-5 -8.118794f-5 9.0774f-5 7.052984f-5 -9.2935195f-5 0.0001478184 0.0001469053 0.00017341968 -0.00029125644 -3.719563f-5 -0.000108039916 -0.00014066632 2.132447f-5 -9.073704f-5 7.622738f-6 0.00014544817 -0.0001513618 2.9640409f-5 7.276617f-5 -0.00013355179 -0.00011909324; -1.1431653f-6 -4.1830674f-5 0.00020059747 -3.805386f-5 7.063412f-5 0.00010768697 6.111935f-5 -6.386147f-5 0.00011432201 -4.2607375f-5 -4.7387617f-5 6.488285f-5 3.6462206f-5 0.00014978628 7.0789f-5 -2.6764974f-5 0.00012981684 -1.6343103f-5 -0.00013138352 1.9093968f-5 0.00012515156 2.7637327f-6 0.0001027212 -9.094625f-6 -0.00017861203 4.8180027f-6 6.138929f-5 3.881941f-5 3.4086406f-5 -3.8179893f-5 -5.7103993f-5 0.00012825224; -4.5868246f-6 -4.6386685f-5 3.2407792f-5 -4.8347454f-5 -0.00027806882 -0.00012535906 1.6981385f-5 0.00010004016 6.9769717f-6 -0.00014598218 -7.3471085f-5 5.6358393f-5 7.626752f-5 -0.00010403047 -2.7225085f-5 1.6116492f-5 -8.598341f-5 9.874891f-5 -6.34476f-5 8.958598f-5 -1.9507217f-5 1.0811174f-5 -3.366249f-5 6.8201625f-5 -6.296191f-5 8.759055f-5 8.3317034f-5 -1.9489737f-5 -1.619985f-5 4.307418f-5 3.4203556f-5 0.000120543205; -0.0001749017 6.0172737f-5 -0.00012345 -0.00013696341 6.4782303f-6 8.8225905f-5 -2.776452f-6 8.23779f-6 -3.1158386f-5 -3.912766f-5 0.00018139851 8.40538f-5 9.7273165f-5 9.248449f-5 -0.00018887423 -0.00013491655 8.995411f-5 9.1885595f-6 -3.334292f-5 -8.37572f-5 -4.4644974f-5 -0.00013083218 -9.4252086f-5 -1.3320466f-5 7.499567f-5 7.843602f-5 -3.8526552f-5 9.0716676f-5 -0.00023685067 -0.00016016095 0.00020340853 2.4760668f-5; 3.4392866f-5 0.00013635638 0.000108649685 -2.9796714f-5 -5.2138614f-5 -2.450751f-5 
-9.8710254f-5 -9.191614f-5 -0.00011825382 -2.8982777f-5 -2.3153054f-5 -1.3614598f-5 -2.2564333f-5 -7.2488714f-5 -8.139346f-5 -0.00010835252 2.732413f-5 -8.7962064f-5 6.838455f-5 -5.1464114f-5 -0.00016326839 -4.7350935f-5 -0.0001499405 -0.00017136431 -0.000112860966 -8.390465f-5 -8.349454f-5 -3.15082f-5 8.266927f-5 8.434909f-5 3.299236f-5 7.7126424f-5; 0.0001115937 1.00625075f-5 -0.00011579857 -2.7800243f-5 0.00012086565 -5.1435378f-5 -2.5991572f-5 -5.5829452f-5 1.2546425f-5 0.00024656646 0.00010286773 0.00021392091 -5.359037f-5 3.4415818f-5 0.00017203705 8.841115f-5 3.2869215f-5 0.00010197107 -3.154579f-5 -0.00016466767 -0.00013835893 -0.0001847451 -0.00011711182 0.0001635272 -0.00021605098 -0.00016279025 0.0001187695 3.237641f-5 -0.00026390832 6.736794f-5 -9.311126f-5 7.579837f-5; -4.795925f-5 -0.00017195968 7.504384f-5 0.00010026909 -2.6507108f-5 9.349206f-5 -9.4892865f-5 7.0988135f-5 5.3035455f-5 0.00016342982 -8.978936f-5 0.00030916906 -7.628187f-5 -0.0001845066 8.2968116f-5 -7.434308f-5 -1.6004678f-5 3.0554387f-5 0.00018095326 -3.6349003f-5 -5.562586f-5 3.5274097f-5 -1.0249751f-5 -5.6561817f-5 -5.654012f-5 5.7649213f-5 -2.7833816f-5 -1.07598535f-5 -7.3985175f-5 0.00020482433 6.51229f-5 9.252003f-5; 2.7485773f-6 3.1959647f-5 -0.0002517222 6.43532f-5 -7.672241f-5 -7.21488f-5 -1.27722315f-5 -1.6988235f-5 -6.148495f-5 6.608065f-5 0.00016663592 0.000104211395 -3.574434f-5 -5.554403f-5 0.00015360562 4.0699186f-5 3.8885714f-6 -9.8087476f-5 -0.00013410936 -2.13802f-5 0.00010269968 -0.0002235178 -0.00012481824 -5.7375808f-5 -0.00011813085 -7.7346216f-5 0.00013995911 2.358551f-5 -7.799403f-5 -0.00013934179 -1.7993574f-5 3.130517f-5; 5.0395396f-5 -0.00014910333 -5.738833f-5 -4.432589f-5 -0.00021973813 0.00027061254 -4.433954f-5 -2.4373832f-5 6.584281f-6 5.943137f-5 2.3617204f-5 2.3323053f-5 -1.05054f-5 2.1509217f-5 -0.0001244022 1.3239537f-5 2.1059775f-5 -7.175686f-5 7.338638f-5 5.1005332f-5 9.4707735f-5 -0.00012272659 -6.934399f-5 5.7756293f-5 5.5485645f-5 -5.6874655f-6 
0.00019415439 -4.5835848f-5 0.00016432274 -0.00014098905 -7.7937446f-5 -4.565583f-5; 9.548742f-5 8.8104855f-5 1.1489065f-5 1.70796f-5 4.05873f-5 3.7644033f-5 -4.5085984f-5 -4.1873107f-5 9.368923f-5 -0.00012526028 9.401317f-5 6.887658f-5 0.00011245381 1.2991295f-5 8.105124f-5 1.5037452f-6 2.1820753f-5 -3.5169913f-5 -4.6070123f-5 2.3281402f-5 -7.493207f-6 8.102095f-5 -0.0001753726 6.0631162f-5 2.652885f-5 3.05282f-5 -6.17771f-6 -8.106573f-5 7.286656f-5 -0.00011587083 -0.00010940261 -1.6667946f-5; -4.1796433f-5 0.0001201084 4.9772174f-5 -0.00013398791 7.090579f-5 6.660148f-5 9.9124845f-5 -0.00018039804 -9.279093f-5 5.040079f-5 -3.945905f-5 -0.00015805473 -0.00015954772 0.00011874505 -3.501961f-5 0.00010261248 0.00013617036 2.4169367f-5 -7.266364f-5 -2.5755524f-5 0.00012210949 0.00017489246 -6.0280323f-5 7.540128f-5 -4.2450763f-5 8.4778956f-5 -1.132527f-5 -7.896973f-5 0.00018433089 0.00019942445 4.5758377f-5 -0.00021893195; 8.150925f-5 -6.996695f-5 5.4715587f-5 -4.0337858f-5 0.00017565621 -8.767079f-5 6.563908f-5 -8.857945f-5 -1.3582419f-5 0.00019731445 -4.4211625f-5 0.00018605991 0.00017428536 0.00013893562 -9.5681906f-5 4.1383162f-5 -1.577123f-5 -1.9598332f-5 1.3721402f-5 -9.534453f-5 5.5300683f-5 -2.2394783f-5 1.092299f-5 5.759538f-5 -4.675608f-5 -2.2987653f-5 -9.4738716f-5 0.00010333382 4.3000582f-5 1.4843544f-5 -0.0001394841 0.00029875588; -3.3324995f-5 3.3323067f-5 -0.00016355225 0.00011155549 5.113697f-5 9.434331f-5 -4.9931663f-5 3.1934974f-6 -5.965116f-6 -9.7589626f-5 8.267505f-5 -8.809f-5 7.149281f-5 -0.00016669279 -3.6451795f-5 0.0001070702 3.589978f-5 0.00014488275 4.1999592f-5 8.542796f-5 0.0001568005 -0.0001314081 -9.93866f-5 -5.1894956f-5 0.00011380796 5.180706f-5 4.3366348f-5 -0.00013683659 2.342519f-5 -0.00012618814 7.1706076f-5 6.8396635f-5; -1.7590404f-5 3.5519492f-6 -7.4458105f-5 -7.8977326f-5 -0.00011279271 -4.0347754f-6 4.0752234f-6 -4.2879932f-5 7.543045f-5 3.670438f-5 8.244848f-5 -3.7206657f-6 9.032478f-5 4.778153f-5 -5.7927504f-5 -7.839252f-5 
6.8207395f-5 0.00012987672 -0.00014525893 -3.7948954f-5 -2.6338223f-5 -0.00017794072 0.00014846133 -0.00012605857 4.76767f-5 2.557749f-5 0.000117504205 -9.039296f-6 -0.00014559209 -0.00013975445 5.3516895f-5 6.312331f-5; 7.0846385f-5 -0.00012753894 -1.8262132f-5 -9.6120835f-5 0.00024581162 6.7716064f-5 0.00010380628 2.1722106f-5 0.00021425949 2.5951364f-5 0.00016495884 -0.00021678908 -0.00013444107 0.00018361922 -2.689647f-5 -8.036979f-5 -0.000113886155 -5.6841374f-5 0.00013963396 -3.932045f-5 0.00015401756 -0.000103131075 -0.00019168004 8.281521f-6 6.7998066f-5 -0.00031813234 -1.2020308f-5 3.4881075f-5 -2.2972892f-5 -0.0001548569 -3.7985385f-5 3.7497408f-5; 4.9228085f-5 0.00015314888 -6.4044085f-5 -0.00015646934 1.5414771f-5 -3.639949f-5 -9.59889f-5 -0.00025412763 2.7047461f-5 -5.077092f-5 5.258801f-5 -2.1262389f-5 9.047698f-5 -1.6298243f-5 0.00014908965 0.00021567538 0.000119389624 6.346466f-6 -0.00019298447 0.000110181936 -8.582614f-6 4.0830815f-5 -4.724655f-5 -7.377032f-5 -6.2259f-5 3.5945166f-5 0.00014039181 8.325257f-5 4.052065f-5 -4.926663f-5 -1.8044051f-5 -6.651885f-6; 3.9613893f-5 -0.00012909873 -2.1522073f-5 -0.00025180468 -9.182792f-5 -0.0001475916 -0.00019842347 5.1894144f-6 7.1156435f-5 -0.00019479016 -3.2328113f-5 1.4162155f-5 -0.00010940932 4.7878988f-5 2.7114506f-6 0.00013948482 -8.20173f-5 -2.476321f-5 -3.6328016f-5 -4.597607f-6 -0.00013157725 0.00016876409 0.00011583638 -5.8664893f-5 0.00012847479 1.9677977f-6 -3.8417115f-6 -4.7470698f-5 4.850922f-5 -6.2099374f-5 6.083188f-5 5.4473076f-5; 1.6541006f-5 4.410444f-5 -0.00014626764 -0.00010445467 0.000109975874 7.847271f-5 3.0729174f-5 -1.913149f-5 6.585681f-5 -5.7936762f-5 0.00018681664 -4.3802273f-5 -0.00015770688 -3.928f-5 2.345331f-5 -7.938204f-5 9.057382f-5 2.730839f-5 -0.00014389955 -0.000114292096 3.4536333f-5 0.00014653074 0.00014470697 -5.8112873f-5 -9.641516f-5 -0.00021086939 0.00013280343 9.396289f-5 -4.315125f-5 1.3059146f-5 0.00027114674 0.0002226126; -4.7052818f-6 -4.4260705f-5 
7.371006f-5 3.3023156f-5 8.573333f-5 1.3152049f-5 7.706682f-6 -1.5233942f-5 -0.00014491323 -9.739344f-6 0.00018602081 -0.00012195515 -6.3300184f-5 5.5421264f-5 4.150605f-5 0.00013925791 -0.00015974743 -4.9869004f-6 -4.77706f-5 0.00017404115 4.999979f-5 1.5544599f-6 0.000115629766 -9.2889575f-5 -2.558891f-5 1.6776352f-5 8.248136f-5 -2.9308982f-5 3.6977006f-5 -8.326385f-5 -5.0950434f-6 0.00018592757; 0.00012965428 0.00013400395 7.6611155f-5 -0.00014770107 0.00021079795 -0.00012886878 -0.00012528723 -1.633015f-8 3.3319295f-5 7.057257f-5 -5.5252818f-5 1.5773303f-5 -5.528159f-5 1.057125f-5 2.0634689f-5 3.5888657f-5 -3.5829304f-5 -9.539635f-5 8.475954f-6 5.3287775f-5 6.0401024f-5 4.029477f-5 0.00011422674 3.2576307f-5 -0.00026310928 2.779777f-5 -2.899275f-5 -0.00012657091 -0.00017139192 -2.9878349f-5 9.51809f-5 7.168047f-6; -7.434498f-5 -8.591183f-5 0.00011529131 0.00011616861 5.999175f-5 -8.550103f-6 1.4038615f-5 -0.00013765813 0.00022910953 -0.00010456774 0.00013956695 1.07146825f-5 6.7757935f-7 -2.8850767f-5 3.86379f-5 -5.0162267f-5 -7.004268f-5 4.4010627f-5 -2.151088f-5 -0.00022251051 -2.2048818f-5 -1.4044227f-8 -3.545067f-5 -3.7851027f-5 0.00018727871 0.000107721295 0.00013540454 0.00010730498 -0.00013982365 6.047344f-5 3.723915f-5 -0.00028035187; -0.00011000284 -3.0138096f-6 6.0499974f-6 -6.109422f-6 6.816686f-5 8.166895f-5 0.00010798777 0.00019587908 -1.4481197f-5 6.4603286f-5 -0.000105121726 -7.184211f-5 0.00013423387 -0.000115500225 1.4687183f-5 3.021017f-5 -0.00012694976 -1.2602762f-5 -6.703371f-5 -5.9768732f-5 -2.7392653f-5 1.1936158f-5 -7.2658055f-5 -4.8214126f-5 -0.00011574807 9.386914f-6 -5.768398f-6 3.4297824f-5 -0.00011155111 -0.00019593802 0.00011155002 7.951286f-5; 0.00011948413 0.00010091895 2.6238746f-5 0.000118959855 -6.7613146f-5 8.601657f-5 8.3948195f-5 -0.00019492878 0.00010483925 0.0001374937 -3.0986947f-5 -7.0423295f-5 -5.950575f-5 -0.000113730806 6.5442706f-5 -2.1962613f-5 -1.2937627f-5 7.2304965f-6 -2.1169253f-5 1.0575395f-5 -8.935209f-5 
-0.00019780888 -0.00012746913 8.883125f-5 8.3176055f-5 0.00018069318 -0.00014460056 3.79141f-5 5.0987514f-5 0.000116375166 0.00013342658 -3.8849917f-5; -8.050089f-5 0.00023510997 7.476176f-5 -1.7655705f-6 -0.00013949386 -0.00017199643 -3.4259494f-5 -3.797127f-5 8.541953f-5 -1.9472469f-5 -0.00012058171 5.6471108f-5 -7.988716f-5 0.0003285953 -7.448997f-5 9.576991f-6 8.681409f-6 -2.9485195f-6 -8.315706f-5 1.8645953f-5 -0.00013632761 -3.0497815f-5 7.233996f-5 8.998996f-5 -4.2001728f-5 -2.1695625f-5 1.3934601f-5 0.00011437099 7.5558355f-6 2.5517962f-5 1.9873149f-5 -6.4184875f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[4.4616872f-6 -5.150322f-6 4.9941405f-5 -0.00012370726 -1.9322812f-5 8.383851f-5 -2.916828f-5 -7.848364f-5 -8.462423f-5 5.758053f-5 1.457892f-5 -0.00010173954 2.2763948f-5 -0.00011277133 9.739188f-5 0.00014119297 0.0001329381 7.339942f-5 -0.00010895764 -4.5828906f-6 3.4873166f-5 -3.8898197f-5 0.00014921636 8.310014f-5 0.00020034163 0.00022095068 -0.00016554163 2.3531824f-5 5.827376f-5 0.00030733025 5.9006084f-6 2.1624657f-5; -0.00014957588 -7.971886f-6 2.36406f-5 8.1495615f-5 0.00024106556 0.00020465541 -0.000109545275 -0.0002470758 -6.729019f-5 -0.00025968868 -0.0002119596 7.674428f-5 0.00014276577 -2.6860178f-5 -0.00010096451 2.8927703f-5 -0.00014684851 3.6662655f-5 -9.377643f-5 0.00010012105 -6.2467276f-5 0.0001868935 -9.728503f-5 -3.168174f-7 -2.2073727f-5 7.999627f-5 -0.00010033393 0.00014961357 -8.222044f-5 -0.00021848016 -6.119947f-5 -4.4171826f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using `Float32`; however, in this case we need `Float64`.

julia
# Convert the parameters to Float64 and wrap them in a flat ComponentArray. This is
# the vector the optimizer updates.
const params = ComponentArray(ps |> f64)

# Bundle the network with its state. Parameters are passed explicitly at each call,
# e.g. `nn_model(x, nn_params)`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

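Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, using $u[1] = \chi$ and $u[2] = \phi$. The neural network $\mathcal{N}$ multiplies each Newtonian rate by a learned correction:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $p$, $M$, and $e$ are constants.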

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE. The state is `u = (χ, ϕ)`, and the orbit constants
`(p, M, e)` come from `ode_model_params`. The Newtonian rate `(1 + e cos χ)² / (M p^{3/2})`
is multiplied by `1 .+ nn_model([χ], nn_params)`. `t` is unused.

Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example the network's state `st` (stored in `nn_model`) is an empty
    # NamedTuple, so it can safely be ignored; in general, `st` should be threaded
    # through to hold the network's state.
    correction = 1 .+ nn_model([χ], nn_params)

    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings as the data, and take the "+" polarization of its waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the data waveform and the untrained network's waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the data `waveform` and the "+" polarization predicted by
the neural ODE with parameters `θ`. The solver settings match those used to generate
the data.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(first(predicted), waveform)
end
loss (generic function with 1 method)

Warm up the loss function

julia
# Evaluate the loss once (forward solve + waveform) at the initial parameters before training.
loss(params)
0.0007422875324609472

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, current_loss)

Append `current_loss` to `losses` and print the iteration number (`state.iter`) and
the loss. Returns `false` so that optimization continues.
"""
function callback(state, current_loss)
    push!(losses, current_loss)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, current_loss)
    halt = false
    return halt
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This appears to work well because the Newtonian model already provides a very good initial guess.

julia
# Fit the network parameters with BFGS (backtracking line search, small initial step),
# using Zygote reverse-mode AD for gradients. The objective ignores the second
# (hyperparameter) argument that Optimization.jl passes.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.00015260984946504296; -7.86469972808941e-5; 0.0001024521843645468; -0.0001942759408845977; -0.0001150322277678096; 2.5274566723933263e-5; -0.00010334692342426808; -4.2563049646607875e-5; 0.00015810276090631297; -6.772859342153104e-5; 8.890157187124903e-5; 1.5272062228167252e-5; -4.195005385553845e-5; 8.613631507611514e-5; -6.244482210596956e-5; -1.417518069500993e-5; -5.2080398745552706e-5; -9.597998723612779e-5; 3.4229946322745406e-5; -0.00012799038086083766; 0.0001997448125618493; 0.00014724273933077475; -4.616568185152085e-5; 9.398018846688567e-6; 8.354277088079894e-5; 6.936131103424453e-5; 7.678187830603915e-5; -0.00014963455032551655; -0.00015557905135199756; 0.00018619792535863936; 0.00010090831347053447; -9.52850896281956e-5;;], bias = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = [0.00010336538710292511 0.0001430410530058281 -6.0982611986008304e-5 0.00012659632650796585 -1.4683674608720895e-5 -3.045424179687829e-5 -2.0026022840073143e-5 4.728390676273752e-6 -1.635606074229711e-5 5.142946546794512e-6 -0.00010683007255535682 -8.30292947018912e-5 6.010210657153093e-5 5.914981104909329e-5 -5.205302359871985e-6 -4.08979276205806e-5 0.00011712824655590406 0.00012961405903227533 -2.489763844166143e-5 3.607744022632014e-5 -0.0001389206981646241 9.539694022281156e-5 -0.00021568692588586544 -9.0229509597311e-5 1.06945481358992e-5 -8.689525836486517e-5 7.089281737465278e-5 0.00015622095545880565 -8.582432471554633e-5 -0.00013588311682794017 4.9729817900908315e-5 -8.318545858016333e-5; -2.7462306999894513e-5 6.533421958744734e-5 -2.3265130832381097e-5 0.00010730799670278042 -2.5131190224723537e-5 0.00013817439151177598 -0.0002527705236133743 8.861053914012758e-5 -0.0001246200862369178 -5.865557784266503e-5 -2.0223069324794666e-5 -4.762469131667667e-5 
-2.382298759760426e-5 -4.3505418827345856e-5 -5.059308361460544e-6 0.0001530651044304899 9.750080619421192e-5 9.194754796050583e-5 0.0001361079613416868 5.0232220135571524e-5 2.3454178634033136e-5 0.00010960635529778522 2.0616711597604562e-5 -0.00014816013742522825 -6.280128752399058e-5 8.316847936664036e-5 -0.00012066082296944937 2.6821197141600455e-5 2.4996426989729448e-5 4.1318888295842976e-5 -0.0002309297141181858 0.000176962987214272; 4.451518476578962e-5 -0.00015296701598110587 0.00018965497589540822 -0.00034871785168198786 -2.881362410100401e-5 2.0695171099373504e-5 -0.0001238872107364128 -1.3163349762404382e-6 2.9776442954774508e-5 -1.503370980337755e-5 2.0423286753227765e-5 4.506392657510774e-5 -4.010390761088695e-5 -7.675666414563483e-5 4.307842324358363e-5 2.4863321652969575e-5 -0.00011076342046856033 6.0971512778067264e-5 0.00020824194804170036 2.7960071076977674e-5 2.312392940674583e-5 -5.3628584953383144e-5 -9.248800482730484e-5 2.1837243605273036e-5 -0.0001135755126028224 2.3028294219854088e-5 0.0001908044753350412 0.00014296762971078577 5.803304486637949e-6 -0.00026206937467342263 -2.4805684677104855e-5 0.00014120616402813345; 0.00010141584572997844 5.3543213308191706e-5 -0.00018893710728584877 -1.8789727429530088e-6 0.0001904824290439572 2.5341510646146333e-5 -0.0001185732929835943 0.00011877248982350215 -2.0674490205375718e-5 7.754051113883851e-5 0.00010370951133231534 -5.217880459818192e-5 -5.513849317062774e-5 -2.463106318750987e-5 1.3986876827833283e-6 0.00010625141078057562 3.0399767293463156e-5 -0.00012733543042409327 -2.3947605203968845e-5 -2.7751985161696252e-5 -8.860606048211497e-6 -0.0001608853220924646 -6.679370051565706e-5 -5.551229041114173e-6 8.638966903240317e-5 -6.710883679184244e-5 -0.00017555186275708596 3.104796780539641e-5 9.460079092682314e-5 0.00013064576853216643 1.321474666915278e-5 -6.543712004636508e-5; 3.3288934533141245e-5 3.8172240779629185e-5 0.00017684565190890345 3.0030360535837546e-5 9.122036858994367e-6 
-1.9213902673555824e-5 7.241110378842338e-5 -4.647868027362285e-5 8.160039087257463e-5 4.4620292642746715e-5 0.0001009768335238432 5.783458490938674e-5 -0.00013668720694525412 6.215571007002974e-6 -0.00010776756007956636 -6.969133883515576e-5 -2.040582094608141e-5 -2.3529333161180455e-5 7.04591098796926e-5 -4.522201689227158e-5 6.76689038259231e-5 -2.5312071924906133e-5 1.971267247014893e-5 -2.8020057838631016e-5 -1.6339217294928616e-5 -5.959137106813808e-5 -0.00018661449339973154 -3.828432218555183e-5 0.0001274331249037119 3.052726114716255e-5 -5.8721985082714675e-5 6.0575343555198466e-5; -6.69689536280953e-5 6.493620081466852e-6 5.4819397324471316e-5 -0.00014032600230538738 3.460856968188304e-5 8.52341115987087e-5 3.818166513695232e-5 -0.00015685754552947732 -0.00026720324896276374 0.00015604256694672032 8.614719334757959e-5 0.00012666959881784152 2.655513961684522e-5 -2.3525898735959842e-5 0.00027174512141385515 -2.429244451747054e-5 -1.4890699470983604e-5 -8.366182666038194e-5 1.0595232912180065e-5 9.930191371124564e-5 -1.1526768697965159e-5 4.9709463188190895e-5 -7.997350550255204e-5 -6.038857392383042e-5 -2.402791973571001e-6 -3.0518905272461995e-5 0.00018039558998310238 9.995214421534754e-5 -0.00014679243687540591 2.5738797181390947e-5 6.779722824460715e-5 0.00015436346510426115; -0.00011654447964816528 3.7096322079135496e-5 7.708499967860419e-5 5.6222653687422115e-5 -0.0001902204844793819 -4.202678179598337e-6 0.0001332735532503053 1.5580776993130314e-5 8.194974486645739e-7 9.623187428344413e-6 -8.379592624138572e-6 0.0001907775271990713 -8.901760386494828e-5 -1.5776233818784507e-5 -0.00012328033565269628 0.00014974600114667826 -0.00010503662135876594 -9.112117052269779e-5 -9.518759409390318e-5 5.841431400198101e-5 0.00012128180739977739 -9.365858798106424e-5 -5.143556202371505e-5 -6.56884535840336e-5 3.619265541941177e-5 5.705238867206009e-6 8.927683436352977e-5 -0.00018928748843451232 -3.252804028876271e-5 7.120521101909867e-5 -1.8165998450299947e-5 
5.587536859250944e-5; -8.919395014158924e-7 4.109992609932558e-5 -0.0001153610222473279 0.00011555484402097289 0.0002488121596315373 -3.282448720078722e-5 9.326374006099003e-5 2.96951062836783e-5 -3.6111840500613855e-6 9.084416401214626e-5 4.885680995988925e-5 2.8103745214094035e-5 -1.4226452870435863e-5 -0.00010649048031433162 2.7220487615473512e-5 0.00010990430623401201 4.298825901691751e-5 7.109103775923251e-5 -7.195718476673458e-5 -3.1242719481334025e-5 4.513135231436435e-5 8.993494579674242e-5 -4.4614910942120296e-5 -4.933559589794577e-5 -0.00012857445958894303 3.0412059138092944e-5 0.00015549628610364666 -3.3451769334604556e-5 -3.857547685864183e-5 0.00017333969903748372 -0.00012527062185932887 4.8451779227379115e-5; -0.00012982418661417012 0.00012063610961482778 0.00023806112335583937 3.215492802260515e-5 9.172799514195625e-5 0.00011762069812099503 2.881431396523299e-5 -4.940535517412668e-6 -0.0001373691363813503 8.04878193331785e-5 -1.4341839913720102e-5 6.188762677189958e-5 -8.118720393107245e-5 9.077474283109306e-5 7.053057675071967e-5 -9.293445577789964e-5 0.00014781913819312121 0.00014690603461637005 0.000173420424517794 -0.00029125569812898857 -3.71948906366432e-5 -0.00010803917643792459 -0.00014066558173802664 2.1325208991734075e-5 -9.073630167116513e-5 7.623477587152269e-6 0.00014544890768880883 -0.00015136106480789853 2.964114833153725e-5 7.27669078070499e-5 -0.00013355104851947347 -0.00011909249837921153; -1.138800892755361e-6 -4.182630920993612e-5 0.00020060183185406823 -3.804949421767475e-5 7.063848777462199e-5 0.00010769133102869966 6.112371800255798e-5 -6.385710851033969e-5 0.00011432637136448746 -4.260301040929029e-5 -4.7383252716173844e-5 6.4887214354467e-5 3.646657085185128e-5 0.00014979064751671005 7.079336380841091e-5 -2.676060951059271e-5 0.00012982119976798125 -1.6338738826820143e-5 -0.0001313791593231792 1.9098332519847383e-5 0.00012515592850528948 2.768097080385211e-6 0.00010272556271659346 -9.090260685542962e-6 -0.00017860766213132626 
4.822367078962101e-6 6.139365603005108e-5 3.882377375402503e-5 3.409077060195372e-5 -3.817552835546321e-5 -5.709962835979739e-5 0.00012825660694644854; -4.587936542427796e-6 -4.6387797352421964e-5 3.240667995132485e-5 -4.83485660660262e-5 -0.00027806993473486897 -0.00012536017181983598 1.698027283427757e-5 0.00010003904883350836 6.97585976386373e-6 -0.00014598329160062833 -7.347219668727484e-5 5.6357280686000114e-5 7.626640710813582e-5 -0.00010403158246472612 -2.7226197358600636e-5 1.611537975267573e-5 -8.598452275319332e-5 9.874779918763718e-5 -6.344871378669878e-5 8.958486562196362e-5 -1.9508329001738335e-5 1.0810062507609082e-5 -3.366360267673947e-5 6.820051332769512e-5 -6.296302178807705e-5 8.758944152222716e-5 8.331592236945218e-5 -1.9490848513576697e-5 -1.6200962432683277e-5 4.307306646670151e-5 3.4202444498166114e-5 0.00012054209348579796; -0.00017490410778554278 6.0170324781205315e-5 -0.0001234524063474987 -0.00013696582414557114 6.475818048784202e-6 8.822349240653122e-5 -2.77886420554055e-6 8.235377885275522e-6 -3.1160798476766666e-5 -3.913007117808571e-5 0.00018139610232901322 8.405138555098197e-5 9.727075279805808e-5 9.24820812942878e-5 -0.00018887663850910793 -0.00013491896630149467 8.995169968692658e-5 9.186147275610295e-6 -3.334533199086836e-5 -8.3759612840365e-5 -4.4647386181146775e-5 -0.00013083459294284015 -9.425449775882038e-5 -1.3322877981419667e-5 7.499325487878592e-5 7.843361140111706e-5 -3.852896430181009e-5 9.07142637006546e-5 -0.0002368530806561378 -0.00016016336690516346 0.0002034061214944788 2.475825544388156e-5; 3.438724953572082e-5 0.00013635076239709143 0.00010864406859739893 -2.9802330835273328e-5 -5.214423007442824e-5 -2.4513126198405046e-5 -9.871587034309968e-5 -9.19217557459746e-5 -0.00011825943446477737 -2.898839310389912e-5 -2.315867030872309e-5 -1.3620214566151556e-5 -2.2569948931308802e-5 -7.249433045969715e-5 -8.139907666944065e-5 -0.00010835813672046039 2.7318513013340123e-5 -8.79676802362784e-5 6.837893371240557e-5 
-5.146973061369349e-5 -0.00016327400161408678 -4.7356551738851904e-5 -0.00014994612073522494 -0.00017136993054763297 -0.0001128665819926698 -8.391026430061014e-5 -8.350015494968463e-5 -3.151381614610905e-5 8.266365394935001e-5 8.434347611162903e-5 3.298674411028774e-5 7.712080761096384e-5; 0.00011159371387764226 1.006252386575921e-5 -0.00011579855163729397 -2.7800226239218713e-5 0.00012086566311625184 -5.143536146062998e-5 -2.5991555962559026e-5 -5.58294359551148e-5 1.2546441207792343e-5 0.0002465664747389162 0.00010286774522244353 0.00021392092639432647 -5.3590354586798834e-5 3.441583430246806e-5 0.00017203706556317752 8.841116686669537e-5 3.2869231286145604e-5 0.00010197108530990201 -3.154577462266772e-5 -0.00016466765192019406 -0.00013835890968477946 -0.00018474508943183932 -0.00011711180377899312 0.00016352722195295044 -0.00021605095935667512 -0.00016279023657700623 0.00011876951793526655 3.237642703895804e-5 -0.0002639083086294926 6.73679551211313e-5 -9.311124268135633e-5 7.579838892995733e-5; -4.795660169972222e-5 -0.0001719570293038945 7.50464885664273e-5 0.00010027173884965963 -2.650446112734896e-5 9.349470820102847e-5 -9.489021722170463e-5 7.09907824807977e-5 5.303810285343525e-5 0.00016343247152922336 -8.978671503196168e-5 0.0003091717069788095 -7.62792198633598e-5 -0.00018450395254826762 8.297076310787069e-5 -7.434043181769504e-5 -1.600203104084998e-5 3.055703455406427e-5 0.00018095590239914093 -3.6346356134763675e-5 -5.562321275418329e-5 3.5276744531560955e-5 -1.0247103270964078e-5 -5.655917011376416e-5 -5.65374732081647e-5 5.7651860336182776e-5 -2.7831168333659767e-5 -1.0757206107396985e-5 -7.398252746267439e-5 0.00020482698055504641 6.512554573650575e-5 9.252267664355425e-5; 2.7446925802239038e-6 3.1955762694089994e-5 -0.0002517260867019936 6.434931188473198e-5 -7.672629666866444e-5 -7.215268792006588e-5 -1.2776116221531227e-5 -1.6992119838872637e-5 -6.148883343791648e-5 6.60767624654911e-5 0.00016663203192056694 0.00010420750973625586 
-3.574822343889882e-5 -5.5547915820530565e-5 0.00015360173770804906 4.069530101023156e-5 3.884686663329883e-6 -9.809136039929693e-5 -0.00013411324411518276 -2.138408430563057e-5 0.00010269579857472327 -0.00022352168102168036 -0.00012482212272825862 -5.737969271959278e-5 -0.00011813473187739544 -7.735010089472732e-5 0.00013995522809327857 2.3581624788335302e-5 -7.799791578088117e-5 -0.00013934567627380345 -1.7997458911297022e-5 3.130128666871912e-5; 5.0395007968870356e-5 -0.0001491037207928141 -5.7388718011102045e-5 -4.432627861784642e-5 -0.0002197385135844927 0.00027061214972821647 -4.433992831433409e-5 -2.4374220407813373e-5 6.583892849118804e-6 5.94309833501343e-5 2.3616815491218386e-5 2.3322664895780596e-5 -1.0505787719344777e-5 2.150882868921591e-5 -0.00012440258684033833 1.323914897842214e-5 2.105938733041899e-5 -7.17572500037612e-5 7.338598992976952e-5 5.100494421970886e-5 9.470734698301335e-5 -0.0001227269774575353 -6.934437511187375e-5 5.775590511223879e-5 5.548525700181458e-5 -5.687853678287373e-6 0.00019415400060428028 -4.583623627364703e-5 0.00016432235610740214 -0.00014098944189423608 -7.793783402700203e-5 -4.5656218168337425e-5; 9.548879982532444e-5 8.810623488018586e-5 1.1490445768072769e-5 1.70809796421161e-5 4.0588680103905205e-5 3.7645413367800306e-5 -4.508460394360968e-5 -4.1871726598101374e-5 9.369061238459672e-5 -0.0001252589011252635 9.401455256950327e-5 6.887795982977943e-5 0.00011245519065626464 1.299267544739186e-5 8.105262325473529e-5 1.5051256062458692e-6 2.182213368571817e-5 -3.516853236730987e-5 -4.6068742694542165e-5 2.3282782176759216e-5 -7.491826426930215e-6 8.102233344317927e-5 -0.00017537121518186745 6.063254270914585e-5 2.653023123487156e-5 3.052958139080911e-5 -6.176329652291442e-6 -8.106434895096796e-5 7.286794219027672e-5 -0.00011586945244715632 -0.0001094012279915426 -1.66665652338213e-5; -4.17945316473193e-5 0.00012011030273818717 4.9774075229224144e-5 -0.00013398600778936242 7.090769438612135e-5 6.66033833807307e-5 
9.91267465690991e-5 -0.00018039613998886548 -9.278902646081476e-5 5.040269250124168e-5 -3.945714845364715e-5 -0.00015805282375701025 -0.00015954582115561118 0.0001187469483523559 -3.501770836343538e-5 0.00010261438226401561 0.00013617226293384624 2.4171268229411192e-5 -7.266174136353962e-5 -2.5753622616025298e-5 0.00012211138753292182 0.0001748943563837397 -6.0278421664101965e-5 7.540318117784097e-5 -4.244886215353128e-5 8.478085735653924e-5 -1.1323368674918556e-5 -7.89678283796811e-5 0.00018433278680851775 0.00019942635342550826 4.576027865944482e-5 -0.00021893004807104237; 8.151342046003337e-5 -6.99627766010717e-5 5.471975747752404e-5 -4.0333687488256456e-5 0.00017566038297046926 -8.766662211597985e-5 6.564324723114634e-5 -8.857528008262877e-5 -1.3578248749037647e-5 0.00019731862413721452 -4.420745461573064e-5 0.0001860640855881103 0.00017428953434831218 0.00013893978904491188 -9.56777352115473e-5 4.1387332763075547e-5 -1.5767058707176754e-5 -1.9594161480543e-5 1.3725572549114705e-5 -9.534035633295265e-5 5.530485360902762e-5 -2.2390612124942647e-5 1.0927161034154067e-5 5.759955239740187e-5 -4.6751907925112574e-5 -2.298348261717446e-5 -9.473454555006589e-5 0.00010333799001782807 4.300475267482343e-5 1.4847714388073073e-5 -0.00013947992409555726 0.00029876005429798404; -3.332393272463517e-5 3.332413005959562e-5 -0.00016355118670067584 0.00011155655418954926 5.113803170938115e-5 9.43443742344673e-5 -4.993060021111582e-5 3.1945601095201734e-6 -5.964053356330796e-6 -9.758856277945077e-5 8.267611018349673e-5 -8.808893427517667e-5 7.149387282520075e-5 -0.00016669172283759995 -3.645073185357796e-5 0.00010707126101446868 3.590084132706257e-5 0.0001448838161068524 4.2000654944829006e-5 8.542901967852336e-5 0.00015680155819248467 -0.0001314070391471986 -9.938553513525524e-5 -5.1893893682020035e-5 0.00011380902331988618 5.180812193979877e-5 4.3367410840842446e-5 -0.0001368355292597345 2.342625201530438e-5 -0.00012618707797989003 7.170713841882714e-5 6.839769820184661e-5; 
-1.7591844504248322e-5 3.5505089850108183e-6 -7.445954505461047e-5 -7.897876602004603e-5 -0.0001127941501058137 -4.036215692813485e-6 4.073783123699237e-6 -4.288137241606797e-5 7.542900895128019e-5 3.6702939190528654e-5 8.244704114399016e-5 -3.7221060087727944e-6 9.032333802091688e-5 4.7780089893006525e-5 -5.7928944067006026e-5 -7.839396092682422e-5 6.82059548248987e-5 0.00012987527626459648 -0.0001452603751990501 -3.7950394095289105e-5 -2.63396628601545e-5 -0.00017794215781254864 0.00014845989100062106 -0.00012606001072204946 4.767526153367996e-5 2.557605004180171e-5 0.00011750276500019245 -9.040736209907509e-6 -0.00014559352674627728 -0.00013975589498517574 5.351545443185813e-5 6.312187216603351e-5; 7.084527516141823e-5 -0.00012754004654838108 -1.826324131528561e-5 -9.612194497733489e-5 0.0002458105116720303 6.771495446897502e-5 0.00010380516670280893 2.1720996131044488e-5 0.00021425837968578915 2.5950253814664436e-5 0.0001649577278048704 -0.0002167901877210531 -0.00013444218001226464 0.0001836181143634923 -2.689758019827262e-5 -8.037090100246108e-5 -0.00011388726505814586 -5.684248407036293e-5 0.0001396328551099583 -3.932155976784381e-5 0.00015401644833489285 -0.00010313218461945422 -0.00019168114903280996 8.280411578162084e-6 6.799695603418788e-5 -0.00031813345066169656 -1.2021417375797165e-5 3.487996565397153e-5 -2.2974001864922734e-5 -0.00015485800776061828 -3.7986494305219564e-5 3.749629814862042e-5; 4.922899763234392e-5 0.00015314979402367046 -6.40431723508995e-5 -0.0001564684250110854 1.5415683556532406e-5 -3.639857884943298e-5 -9.598798776958255e-5 -0.0002541267210103556 2.7048373631731683e-5 -5.0770006673223596e-5 5.258892208123547e-5 -2.1261476540051267e-5 9.047788928048838e-5 -1.6297330766497362e-5 0.00014909056637454866 0.0002156762887539362 0.00011939053667495153 6.3473783671033765e-6 -0.0001929835604455305 0.00011018284869398631 -8.581701585613896e-6 4.0831727638323946e-5 -4.724563830629781e-5 -7.376941055506265e-5 -6.225808890983467e-5 
3.594607853349359e-5 0.0001403927265714857 8.32534818259952e-5 4.052156084118815e-5 -4.926571698841551e-5 -1.80431385883606e-5 -6.650972572493586e-6; 3.961003125949462e-5 -0.0001291025892475216 -2.1525935202255597e-5 -0.00025180854395657175 -9.183178013120497e-5 -0.0001475954687099571 -0.00019842733131884645 5.185552660218556e-6 7.115257353360682e-5 -0.00019479402367655082 -3.2331974629083906e-5 1.4158293148605353e-5 -0.00010941317855544638 4.7875126207475594e-5 2.7075888705057864e-6 0.00013948096052113468 -8.202116264660042e-5 -2.476707236713496e-5 -3.633187775780774e-5 -4.601468893838046e-6 -0.00013158111510491976 0.00016876022886788897 0.00011583251939223876 -5.866875476717482e-5 0.00012847092325170008 1.9639359801288054e-6 -3.8455732952927095e-6 -4.747455937250932e-5 4.8505358742101154e-5 -6.210323595369014e-5 6.082801878289211e-5 5.4469213815869134e-5; 1.6543217326354054e-5 4.410665115862269e-5 -0.0001462654328759873 -0.00010445246169989328 0.00010997808550306197 7.847492071156911e-5 3.073138560572426e-5 -1.9129278542851398e-5 6.585902138899681e-5 -5.7934550740204294e-5 0.00018681885541005976 -4.380006083813876e-5 -0.0001577046643331696 -3.927778760847611e-5 2.3455522578994785e-5 -7.937983190352472e-5 9.057603148100337e-5 2.7310601409374613e-5 -0.00014389734150320742 -0.0001142898838123463 3.453854497923957e-5 0.00014653295494204297 0.0001447091779584013 -5.8110661656273497e-5 -9.641295054173969e-5 -0.00021086717623696447 0.0001328056385420183 9.396509977809945e-5 -4.3149037254669635e-5 1.3061357832286835e-5 0.00027114895038827147 0.00022261480609009874; -4.703030903981298e-6 -4.425845433134161e-5 7.371231443839248e-5 3.302540684241715e-5 8.573557973731205e-5 1.3154300027328117e-5 7.708932421604108e-6 -1.5231690817654043e-5 -0.00014491097423589844 -9.737093173227407e-6 0.00018602306484835523 -0.000121952897793288 -6.329793283357457e-5 5.542351467310784e-5 4.150830030566679e-5 0.00013926016511986905 -0.00015974518206776365 -4.984649571911786e-6 
-4.7768350818568274e-5 0.00017404340436340897 5.000204222166662e-5 1.5567107632559965e-6 0.00011563201663675516 -9.288732432168955e-5 -2.5586659071181042e-5 1.677860276258613e-5 8.248361230953111e-5 -2.9306730689055777e-5 3.69792567974236e-5 -8.326159619603162e-5 -5.0927925846794745e-6 0.00018592981617557094; 0.00012965383391314936 0.00013400350323831224 7.661071319784595e-5 -0.00014770150849176543 0.0002107975105657209 -0.00012886922015855603 -0.00012528767367871846 -1.6772188437818254e-8 3.331885270494809e-5 7.057212676340255e-5 -5.525326016755255e-5 1.5772860909457666e-5 -5.528203294193874e-5 1.0570807692823477e-5 2.063424651463978e-5 3.588821525521982e-5 -3.582974570850398e-5 -9.539679561090117e-5 8.475511979888678e-6 5.328733275131969e-5 6.0400581763359415e-5 4.029432782158048e-5 0.00011422630084165808 3.257586465519326e-5 -0.00026310972136957993 2.77973285603757e-5 -2.899319226958764e-5 -0.0001265713563286719 -0.00017139236117759887 -2.9878790907561934e-5 9.518045587651053e-5 7.167604938552972e-6; -7.434454094128798e-5 -8.591139269502703e-5 0.00011529174732767115 0.0001161690459169974 5.999218844716468e-5 -8.549666041846119e-6 1.4039051869428463e-5 -0.00013765769765687906 0.0002291099670159244 -0.00010456730453971618 0.0001395673905548278 1.0715119763871053e-5 6.780166116990658e-7 -2.8850329933095506e-5 3.86383369672222e-5 -5.0161829882755536e-5 -7.004224537614551e-5 4.401106414060198e-5 -2.151044211733415e-5 -0.00022251007542455582 -2.2048380948597398e-5 -1.3606966267971792e-8 -3.5450232850609285e-5 -3.785058945742214e-5 0.0001872791479157855 0.00010772173188651132 0.0001354049752748213 0.00010730541614374044 -0.00013982321181776132 6.0473878669093365e-5 3.723958868557663e-5 -0.0002803514355278937; -0.00011000457288237736 -3.015544046063071e-6 6.048263009948329e-6 -6.111156374297349e-6 6.816512632875617e-5 8.166721546083818e-5 0.00010798603489755423 0.00019587734733002459 -1.4482930958097591e-5 6.46015514642222e-5 -0.00010512346033712306 
-7.184384139029202e-5 0.00013423213113790574 -0.00011550195903757072 1.4685448156598377e-5 3.0208436335014303e-5 -0.0001269514932507438 -1.2604496071358101e-5 -6.703544476542461e-5 -5.977046657127237e-5 -2.7394386954889347e-5 1.193442404409564e-5 -7.26597891050183e-5 -4.821586025420478e-5 -0.0001157498072577405 9.38517951661389e-6 -5.77013224093863e-6 3.429608933155857e-5 -0.00011155284573097179 -0.00019593975138623075 0.00011154828553779722 7.95111236650744e-5; 0.00011948599623496664 0.00010092081558173671 2.624061036545199e-5 0.00011896171983311007 -6.761128114089657e-5 8.60184320270133e-5 8.39500591762351e-5 -0.00019492691879644572 0.00010484111609616231 0.00013749556293692685 -3.0985081851435395e-5 -7.042143059446892e-5 -5.950388529821808e-5 -0.0001137289416722015 6.544457047926553e-5 -2.1960748002347002e-5 -1.2935762045563752e-5 7.2323611827912546e-6 -2.1167388498033466e-5 1.0577259869705711e-5 -8.935022722950282e-5 -0.00019780701930654116 -0.00012746726659610465 8.883311440218764e-5 8.3177920002305e-5 0.00018069504787195705 -0.00014459869797359988 3.7915963018118224e-5 5.098937819343023e-5 0.0001163770304781223 0.00013342844814974713 -3.884805281199545e-5; -8.050078806800994e-5 0.00023511006972884584 7.476186099128615e-5 -1.7654690745157193e-6 -0.0001394937576386911 -0.00017199633341616173 -3.425939260151423e-5 -3.79711678087075e-5 8.541963277283823e-5 -1.9472367535770663e-5 -0.00012058160839470278 5.64712094947867e-5 -7.988706104325287e-5 0.0003285954148614947 -7.448987210037348e-5 9.577092418196504e-6 8.681510256876511e-6 -2.9484181376646206e-6 -8.315695643060063e-5 1.8646054262800884e-5 -0.00013632750826755827 -3.0497713420034755e-5 7.234005849940545e-5 8.999006192797646e-5 -4.2001626321423666e-5 -2.169552365185859e-5 1.393470254160294e-5 0.0001143710901252158 7.555936894360944e-6 2.5518063443170058e-5 1.9873250015818535e-5 -6.418477326033816e-5], bias = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = [-0.0007393356631356071 -0.0007489476562820767 -0.0006938559346828582 -0.0008675046094933364 -0.0007631201559884656 -0.0006599587624004613 -0.0007729656209622571 -0.0008222808107028937 -0.0008284215773113528 -0.0006862165638324004 -0.00072921841540258 -0.0008455368079369195 -0.0007210329561832059 -0.0008565686853852169 -0.000646405377807458 -0.0006026041779394167 -0.0006108592455005961 -0.0006703979093702509 -0.0008527549401775154 -0.0007483799951603054 -0.0007089241702278806 -0.0007826955192859409 -0.0005945809784738605 -0.0006606972028658619 -0.0005434555242138841 -0.0005228466087932467 -0.0009093389054288502 -0.0007202655256364431 -0.0006855235903350032 -0.0004364670638092245 -0.0007378966953914529 -0.0007221726950017147; 0.00011457627872669469 0.00025618026433141375 0.0002877927515707295 0.0003456477710447178 0.0005052177115525379 0.00046880754250406485 0.0001546068767481068 1.7076288534374535e-5 0.00019686196576909187 4.463388460121063e-6 5.21925475183991e-5 0.00034089640974708266 0.000406917767834216 0.00023729197844525634 0.00016318760983626576 0.0002930797871608378 0.00011730364230717557 0.00030081480256011157 0.00017037571091220742 0.0003642731163235259 0.00020168487516823456 0.00045104565266535076 0.000166867118494305 0.0002638353349721921 0.0002420783589374531 0.00034414840603261303 0.00016381820265108757 0.000413765721896608 0.00018193171714177828 4.5671986928144186e-5 0.00020295267002873603 0.00021998033064331635], bias = [0.0, 0.0]))

Visualizing the Results

Let us now plot the loss over time

julia
# Training curve: one point per loss value recorded by `callback`.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    lines!(ax, 1:length(losses), losses; alpha=0.75, linewidth=4)
    scatter!(ax, 1:length(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results.

julia
# Re-simulate with the optimized parameters `res.u`. Fresh names are used so that the
# global `prob_nn` (read by `loss`) and the untrained `soln_nn` are not rebound as a
# side effect of plotting.
prob_nn_trained = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn_trained = Array(solve(
    prob_nn_trained, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(compute_waveform(
    dt_data, soln_nn_trained, mass_ratio, ode_model_params))

# Compare the data, the untrained prediction (`waveform_nn`), and the trained prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    # Each legend entry combines a line and its markers.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
# Print Julia version, build, platform, and environment information for reproducibility.
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only for packages actually loaded in this session.
# The `@isdefined` checks, combined with short-circuit `&&`, keep the code from
# referencing a module (or its device type) that was never loaded.
if @isdefined(MLDataDevices)
    # CUDA: only if CUDA.jl is loaded and MLDataDevices reports the backend as functional.
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    # AMDGPU: same guard pattern as the CUDA branch above.
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.