Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines the functions used to simulate the model; from a scientific machine learning perspective, however, it isn't particularly relevant.
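We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use the Newtonian formulas $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, with $M = m_1 + m_2$, to get $r_1$ and $r_2$.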
"""
    one2two(path, m₁, m₂)

Split a one-body `path` into the trajectories of two bodies with masses `m₁` and `m₂`.

Each trajectory is `path` scaled by the *other* body's share of the total mass; the
second one additionally carries a minus sign, so the two bodies sit on opposite sides
of the origin.

# Returns
The tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    body_one = m₂ / total_mass .* path
    body_two = -m₁ / total_mass .* path
    return body_one, body_two
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix into Cartesian orbit coordinates.

# Arguments
- `soln`: matrix with one column per saved time step. Row 1 holds `χ` and row 2 holds
  `ϕ`. When `soln` has four rows, rows 3 and 4 supply `p` and `e` for every time step.
- `model_params`: required when `soln` has two rows; a length-3 collection `(p, M, e)`
  from which the constant `p` and `e` are taken (`M` is not used here).

# Returns
A `2 × n` matrix whose rows are `x = r cos(ϕ)` and `y = r sin(ϕ)`, with
`r = p / (1 + e cos(χ))`.

# Throws
- `AssertionError` if `size(soln, 1)` is neither 2 nor 4, or if `soln` has two rows and
  `model_params` is missing or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    # The angular variables live in the first two rows in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        # Checked explicitly so that the default `nothing` produces the same readable
        # error as a wrongly sized collection instead of a MethodError from `length`.
        if model_params === nothing || length(model_params) != 3
            throw(AssertionError("model_params must have length 3 when size(soln,1) = 2"))
        end
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Finite-difference estimate of the first derivative of the uniformly sampled signal `v`
with sample spacing `dt`.

Interior points use the centred difference `(v[i + 1] - v[i - 1]) / 2`; the first and
last points use three-point one-sided stencils. The result has the same length as `v`,
which must contain at least three samples.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, interior, right_edge) / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Finite-difference estimate of the second derivative of the uniformly sampled signal `v`
with sample spacing `dt`.

Interior points use the centred stencil `v[i - 1] - 2 v[i] + v[i + 1]`; the first and
last points use four-point one-sided stencils. The result has the same length as `v`,
which must contain at least four samples.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along `orbit`, scaled by `mass`.

# Arguments
- `orbit`: matrix whose first row holds the `x` coordinates and whose second row holds
  the `y` coordinates, one column per time step.
- `component`: index pair. `(1, 1)` selects `x² - (x² + y²)/3`, `(2, 2)` selects
  `y² - (x² + y²)/3`, and anything else selects the off-diagonal product `x y`.
- `mass`: scalar (or broadcastable) weight applied to the result.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    xx = xs .^ 2
    yy = ys .^ 2
    xy = xs .* ys
    tr = xx .+ yy
    # Diagonal entries have a third of the trace removed; every other index pair
    # falls through to the xy entry.
    selected = if component[1] == 1 && component[2] == 1
        xx .- tr ./ 3
    elseif component[1] == 2 && component[2] == 2
        yy .- tr ./ 3
    else
        xy
    end
    return mass .* selected
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the selected moment-tensor `component`
along `orbit`, sampled with spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the quadrupole contributions for the `(1, 1)`, `(1, 2)` and `(2, 2)` tensor
components of `orbit` and return them as the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    contribution(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    h11 = contribution((1, 1))
    h22 = contribution((2, 2))
    h12 = contribution((1, 2))
    return h11, h12, h22
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the two strain polarisations produced by a single body following `orbit`.

Returns `(c * (h11 - h22), -c * 2 * h12)` with `c = √(π / 5)` evaluated in the type `T`
of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    prefactor = √(T(π) / 5)
    plus = h11 - h22
    cross = T(2) * h12
    return prefactor * plus, -prefactor * cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole contributions of two bodies component by component and return the
tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    a11, a12, a22 = h_22_quadrupole(dt, orbit1, mass1)
    b11, b12, b22 = h_22_quadrupole(dt, orbit2, mass2)
    return a11 + b11, a12 + b12, a22 + b22
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2) mode strain from the orbits of BH1 (`orbit1`, `mass1`) and BH2
(`orbit2`, `mass2`).

Returns `(c * (h11 - h22), -c * 2 * h12)` with `c = √(π / 5)` evaluated in the type `T`
of `dt`.

# Throws
- `AssertionError` when `mass1 + mass2` differs from one by `1.0e-12` or more.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # The masses are expected to be normalised to a total of one.
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    prefactor = √(T(π) / 5)
    plus = h11 - h22
    cross = T(2) * h12
    return prefactor * plus, -prefactor * cross
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into a pair of strain time series.

# Arguments
- `dt`: sample spacing of `soln`; its type `T` is used for the constants.
- `soln`: solution matrix accepted by `soln2orbit`.
- `mass_ratio`: value in `[0, 1]`. Zero selects the one-body strain; a positive value
  splits the orbit with `one2two` using `m₂ = 1 / (1 + mass_ratio)` and
  `m₁ = mass_ratio * m₂`, then evaluates the two-body strain.
- `model_params`: forwarded to `soln2orbit`.

# Returns
The tuple returned by `h_22_strain_one_body` or `h_22_strain_two_body`.

# Throws
- `AssertionError` when `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)

    # A zero mass ratio needs no splitting into two trajectories.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)

    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
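`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.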
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place ODE right-hand side for the state `u = [χ, ϕ]`.

The second argument is destructured into the constants `p`, `M` and `e`; `t` is unused.
Only `χ` enters the expressions. Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)
    # Factor shared by both rates.
    shared = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)
    dχ = shared * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root)
    dϕ = shared / (M * (p^(3 / 2)) * root)
    return [dχ, dϕ]
end
# Problem setup shared by the reference simulation and the neural ODE.
mass_ratio = 0.0  # test particle
u0 = Float64[π, 0.0]  # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4)  # timespace for GW waveform
# time at each timestep
tsteps = range(first(tspan), last(tspan); length=datasize)
dt_data = tsteps[2] - tsteps[1]  # spacing of the saved samples
dt = 100.0  # fixed integrator step handed to `solve`
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the reference model with fixed-step RK4, saving at `tsteps`, and keep the
# first of the two series returned by `compute_waveform` as the training target.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt=dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line with markers at the sampled points.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives.
It is typically never recommended to use globals but in case you do use them, make sure to mark them as const.
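We will deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal distribution with a small standard deviation (`1.0e-4`) and biases start at zero.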
# Network: a cosine applied to the input, then two cosine-activated Dense layers
# (1 => 32 and 32 => 32) and a linear 32 => 2 output layer. All Dense layers share the
# same narrow truncated-normal weight initialiser and zero biases.
const nn = let small_init = truncated_normal(; std=1.0e-4)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight=small_init, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight=small_init, init_bias=zeros32),
        Dense(32 => 2; init_weight=small_init, init_bias=zeros32),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00012044786; 0.00030709585; 5.7402325f-5; 0.00011184847; -6.883471f-5; 7.7050645f-5; 3.0438408f-5; -5.8436588f-5; -1.7628188f-5; -0.00017820473; 6.8499685f-5; 0.00020898171; 9.58901f-5; -9.312582f-5; 8.8136665f-5; -9.456396f-5; -2.034539f-5; 5.0153456f-5; 5.2622378f-5; -7.524623f-5; -2.3429244f-5; -9.0277754f-5; -0.00016811605; 2.9423069f-5; 0.00012099289; 2.1602971f-5; 7.310771f-5; 0.00016930196; -1.484413f-5; 8.3593f-5; 6.1754734f-5; -6.8553796f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[9.4796575f-5 1.1092546f-5 -2.4355564f-5 -7.8037294f-5 -9.967948f-5 -7.106903f-5 -4.9395003f-5 -0.00016805249 0.00016202926 -0.00012363539 2.5839981f-5 1.2228748f-5 7.3185765f-5 -2.612499f-5 2.3616236f-5 1.6182805f-5 -7.112131f-5 -5.5558403f-6 -0.00011036144 6.332822f-5 0.00022915516 2.7695121f-5 7.957916f-5 -9.582184f-6 0.00015391654 5.4069864f-5 -2.6462034f-5 -7.930216f-5 -5.4252836f-5 -9.077891f-5 0.000103421524 0.00010394546; -0.00012220247 0.00010225082 3.192603f-5 -0.000111156696 3.449885f-5 -2.6526666f-5 -7.831516f-5 1.716833f-7 5.0230745f-5 -0.00015993665 0.00012007536 7.4470154f-5 -1.3162433f-5 1.2738403f-5 -0.00015259402 6.6471875f-5 3.199567f-5 6.364548f-5 3.413355f-5 0.00020714986 -2.7715336f-5 -5.538788f-5 -8.4351785f-5 -5.8191657f-5 -2.590694f-5 5.0548533f-5 -6.4259955f-5 -8.987183f-5 3.8789625f-5 6.887738f-5 -5.5484627f-5 -0.00019479598; 9.4215604f-5 0.00019582619 2.1647202f-5 -6.068263f-5 0.00011328327 -9.803558f-5 -0.00010749639 0.000102164246 7.5252196f-5 -1.5049703f-6 1.5062693f-5 -3.9622315f-5 2.8346065f-5 3.6037327f-5 0.00010334588 8.987578f-5 0.00018330796 -8.837212f-5 -2.982158f-5 -1.1723234f-5 0.0001397271 0.00013390997 1.1927522f-5 -8.683966f-5 9.0785616f-5 5.6957364f-5 0.00015078201 
1.2858339f-5 -0.00019527429 5.1439943f-5 -9.3216266f-5 -3.684181f-5; -0.00012463909 2.2059228f-5 1.5048505f-5 5.9537808f-5 9.855503f-5 4.172475f-5 5.422613f-5 1.9696381f-5 9.5981995f-6 9.984575f-6 5.423125f-5 5.976781f-5 -6.3821055f-5 -1.3154446f-5 0.000109067834 -5.6714f-5 -6.026079f-6 0.00017932816 0.00011907978 -0.00022919003 -4.1798543f-5 6.56276f-5 -8.739979f-5 -1.3078985f-6 7.1765853f-6 5.8344544f-5 -7.047919f-5 -0.000107616295 -0.00021593562 -7.329585f-5 0.00010252708 0.00025426634; -2.092879f-5 0.00010610468 8.071125f-5 -5.7278266f-6 2.0254594f-5 7.8648045f-5 6.850536f-5 7.12669f-6 -1.6478303f-5 -0.00015642251 -4.5253135f-5 -0.00015757262 -1.2614584f-5 -2.9468021f-5 -0.00012698102 0.00010910575 -9.8832556f-5 3.4885823f-5 3.1617892f-5 -7.816523f-5 1.9137624f-5 0.0001375127 -1.5709951f-5 -1.574857f-5 -2.679725f-5 0.00023836923 -5.755505f-5 -3.5828657f-6 -1.7919301f-6 0.00017512462 3.992041f-5 0.00012228955; 0.00016422778 0.00012097977 -9.4113646f-5 -8.573783f-5 -5.6602472f-5 -0.00011637877 4.349166f-5 3.1006213f-5 9.015308f-6 -8.8297f-5 6.572088f-5 -0.00015980116 3.5327004f-5 -1.8568548f-5 -3.8305483f-5 0.00010862015 0.000102489714 -0.00011451274 -8.429109f-5 7.512491f-5 1.4618499f-5 -4.1597064f-5 0.00019096755 -1.6315138f-5 0.00011622716 0.00013301345 3.2959557f-5 -1.90215f-5 1.0299681f-6 -4.6355828f-5 -2.787961f-6 6.0562885f-5; -0.00017921945 -8.4630796f-5 -0.00013366195 6.5193824f-5 6.544887f-6 7.9547106f-5 -1.7709712f-5 -1.15125285f-5 -0.0002172024 -9.626359f-6 1.01779015f-5 -9.724797f-6 0.00020380432 4.187985f-5 -2.6682617f-5 -1.7134997f-5 -0.00010807045 5.8100257f-5 6.439622f-5 9.734031f-6 -3.4731144f-7 1.4506851f-5 -8.3452316f-5 -5.9509246f-5 0.0001367691 -1.2715282f-5 1.9113475f-5 -3.209608f-5 2.6154645f-5 -1.6870905f-6 -0.00010711407 4.3062526f-5; 0.00023832372 -8.090494f-6 -7.2124094f-5 -4.8599904f-5 4.5770386f-5 6.36389f-5 -1.8422586f-6 6.443221f-5 -2.1571115f-5 0.00013791097 0.00010693355 4.1140243f-5 -1.2532616f-5 1.8340847f-6 0.0001299273 
1.9697704f-6 1.3528513f-5 8.118175f-5 -3.5809262f-5 -0.00011301762 -2.0473637f-5 -0.00014840737 2.7489953f-6 -0.00011876255 0.0002498537 -0.00020081234 -4.296378f-5 -9.8519224f-5 -1.1409089f-5 -1.2845848f-5 -8.4055115f-5 -0.00014343117; -8.377047f-5 -0.00014294907 -3.1796382f-5 -0.00011412784 -1.2699557f-5 -1.1861272f-5 -2.1532578f-5 0.00010494725 4.4989127f-5 1.0884545f-5 1.6549684f-5 -4.1671297f-5 -9.44836f-5 4.8194248f-5 8.538596f-5 -0.00015519703 4.6112258f-5 0.00015967073 6.234892f-5 -8.1059035f-5 -8.1540114f-5 -9.1283335f-5 0.00013763561 -8.096388f-5 -3.2869866f-5 -0.00016377943 2.0196845f-5 2.0734788f-5 3.6566176f-5 4.806491f-5 -8.4643725f-5 -0.00030385255; 4.9586615f-5 -2.9252804f-5 -9.850272f-6 2.9969007f-5 -7.69332f-5 8.5718515f-5 -3.803991f-5 -2.5092972f-5 6.098103f-5 3.418176f-5 -7.3974785f-5 3.2242235f-5 -2.8085145f-5 6.516339f-5 -7.8025936f-5 4.9561397f-5 -3.1680946f-5 7.26571f-7 -0.00020182709 -0.000103431805 -9.244879f-5 -0.0002838276 0.00013634685 -9.011799f-5 0.00011032208 -0.00016747687 1.770892f-5 0.00012205457 -0.0001818606 -0.000102310725 -0.00010122816 5.099657f-5; -2.8647313f-5 0.00017739406 4.0871404f-5 0.00021700792 8.6909065f-5 8.231314f-5 0.0001960694 -3.0228792f-5 -4.2345753f-5 4.1486786f-5 -3.125059f-5 0.0001548802 -8.256483f-5 -8.555385f-5 8.55829f-6 7.9230695f-6 -0.00029826327 -0.00014797931 1.9270285f-5 5.274198f-5 -5.5096534f-5 3.1399963f-5 -0.00018584187 -4.743371f-5 -1.5284411f-5 3.4065175f-5 0.00013515388 -1.5519632f-5 0.00010257173 -1.9500048f-5 7.720989f-5 0.0001461757; 8.9236564f-5 -3.3286372f-7 6.691308f-5 -7.4565214f-5 -0.0001654869 0.00013563299 1.6626467f-5 -6.274158f-5 2.8983668f-5 7.6170196f-5 -9.02384f-5 -1.2299273f-5 0.000105446175 0.00018477542 9.567336f-5 5.8434263f-5 -1.5404052f-5 -1.9989066f-5 6.052493f-5 0.00012370608 6.986544f-5 8.720782f-5 6.304038f-5 -0.00013386348 0.00013527379 -0.00010999043 9.5945725f-5 -5.5173954f-5 4.7915753f-5 -7.726477f-5 3.8093494f-5 -1.9544417f-5; 4.919027f-5 -6.0184793f-5 
-1.9201554f-5 -6.947959f-5 -5.1049716f-5 -0.000105190535 4.39898f-6 0.00013577804 1.1364066f-5 -3.4468736f-5 -4.2933585f-5 -0.00013568712 4.252199f-5 -6.7942037f-6 -0.00014049228 0.00021724259 0.00011103296 -9.3777264f-5 5.945993f-5 1.3897656f-5 -1.1906968f-5 2.5911017f-5 3.3369473f-5 7.304316f-5 5.1756917f-5 6.0254573f-5 0.00013129637 -2.5737245f-6 -0.00019444039 8.033945f-5 -4.724259f-5 -0.00019807337; 5.6622202f-6 0.00011183516 8.2903225f-6 0.00017488956 3.9717273f-5 -4.416988f-5 6.651576f-5 -0.0001769479 5.9007434f-6 0.00012934698 -7.42041f-5 -4.2002583f-5 -8.6930355f-5 1.6933353f-5 -9.504139f-5 2.7220474f-6 -5.2557334f-5 -0.00011746545 8.187962f-5 3.734891f-5 -0.00023039352 -2.2167753f-5 5.6220237f-5 -4.0844607f-5 2.437647f-6 7.513821f-5 4.768968f-5 -4.0378263f-6 7.8439356f-5 0.0002011961 -7.9444275f-5 7.17738f-5; 2.6961508f-5 0.000104198276 -0.00015558868 2.460823f-5 8.000095f-6 0.00015259116 0.00014627598 0.00015204803 0.00010738079 0.00014117647 8.0138685f-5 -5.4546723f-5 8.9747526f-5 2.98841f-5 -1.4073334f-5 6.119936f-5 -0.0001842493 -3.519749f-5 0.00019735552 4.419367f-5 -2.541061f-5 -9.413738f-5 3.8229635f-5 2.293918f-5 5.9128065f-6 0.00012099237 -2.750577f-5 -0.0001262486 -0.00014517918 -3.1080137f-5 -0.000119426404 6.6779558f-6; 4.4587363f-5 2.0440957f-5 -0.00029177335 8.026116f-5 2.9209486f-5 0.00014586496 9.301757f-5 -3.5119163f-5 -6.498966f-6 -1.644344f-5 -8.1018305f-5 -1.0320208f-5 -0.0001017516 4.531171f-5 -6.8315516f-5 0.00023407293 -6.8158424f-6 0.00011292848 8.459774f-5 -4.6568207f-6 6.0714683f-5 -8.575441f-5 -0.00014532046 -3.1251595f-5 -5.2216117f-5 8.571367f-6 -0.0001007579 -0.00021139444 -0.000100035235 -3.6355865f-5 -5.9637732f-5 -8.283411f-5; -6.392392f-5 -2.828481f-6 -0.00014325608 7.168697f-5 -0.00017632994 -0.00017605517 3.6593552f-5 -4.5984823f-5 9.391094f-7 3.2717122f-5 3.246303f-5 0.00010563696 -4.372685f-5 -7.591815f-5 -2.6965568f-5 5.5988683f-5 -6.638038f-7 -0.00010557409 -1.3496352f-5 -2.4699035f-5 -2.322737f-5 -7.980954f-6 
-6.6117194f-5 0.00011970699 -6.242947f-6 -2.0563051f-5 0.000109648114 -4.0867602f-5 7.9868136f-5 -6.450744f-5 -2.0494874f-5 0.00014891684; -3.479725f-5 -4.926487f-5 -3.211917f-5 7.360601f-5 -0.000129714 0.00017413449 -3.9059755f-6 6.959881f-5 -0.00017769271 -0.0001744295 4.3302753f-5 -7.0011745f-5 -7.6032855f-5 0.00013797714 -0.0001390037 0.00011137242 5.950648f-5 -7.9128455f-5 0.00010529677 -0.0001343249 6.1613187f-6 -5.106914f-5 -8.331876f-5 -0.00028277803 1.3019094f-5 -0.00015658175 -0.00027611628 2.9994024f-5 -6.698321f-5 7.94584f-5 1.9143758f-5 -3.3471406f-5; -0.00011269364 -1.0026684f-7 -9.6454925f-5 -0.00010852306 0.00015217545 -1.1768892f-5 0.00011706602 -5.3811447f-5 -1.5723414f-5 -5.2059608f-5 2.0904692f-5 0.00014612774 -2.2755117f-5 -1.5431977f-5 -8.163029f-5 6.515105f-5 -2.5758394f-5 8.1487684f-5 -9.010032f-5 -4.0233084f-5 -2.901354f-5 9.757913f-6 0.00013343563 -6.794908f-5 -2.9368499f-5 -7.147925f-5 -0.0001264983 -0.00013147604 5.9087797f-6 6.739753f-5 -2.9027906f-5 7.47235f-5; -0.000117242686 -0.00010503218 9.430404f-6 7.029856f-5 -0.00016247669 -2.1389245f-5 6.0471724f-5 0.00017011762 -0.0001895218 0.0002608442 -6.0313632f-5 9.747821f-5 -4.7090278f-5 -7.3311436f-5 -3.544682f-5 2.5590628f-6 1.9235245f-5 6.033701f-5 -0.00011240623 -3.3132088f-5 -6.0751816f-5 5.5436894f-6 -0.00017415296 0.00012355001 -0.00013328386 6.886681f-5 -8.896369f-6 4.395646f-5 -0.00018920413 -0.00018593784 0.00016763287 -6.9581074f-5; -0.00015021108 -4.6943274f-5 6.0896036f-6 6.195464f-6 -8.8209425f-5 7.049611f-6 0.00017356874 5.1342795f-5 -0.00015446865 -0.00016845932 0.00010486016 -4.642073f-5 3.0881827f-5 0.00010388684 0.000207599 0.00020272323 0.00020123049 0.00012065892 5.669295f-5 6.349648f-5 -5.0728286f-6 -4.4787652f-5 0.00010177179 7.3153446f-6 -1.3742259f-6 5.001435f-5 -2.8834042f-5 -9.939668f-5 6.2929767f-6 -6.7683184f-5 0.0001701598 -9.373845f-5; -8.646982f-5 0.00018930726 -9.488893f-5 -4.288654f-5 5.7639147f-5 8.0012294f-5 -3.628507f-5 -9.276406f-5 -4.4993685f-5 
6.2115694f-5 0.00020514867 0.00015640767 -4.66593f-5 9.789081f-5 8.453611f-6 0.00010214874 6.783316f-6 5.5480377f-5 -0.00014552222 -3.2638636f-5 2.1958556f-5 -9.599983f-5 4.5355882f-5 -9.3625036f-5 -7.397935f-6 1.9457426f-5 5.5335673f-5 -0.00011182005 1.2524755f-5 0.00015238427 -6.473792f-5 -9.336243f-5; 7.61536f-5 8.7537264f-5 0.000159032 -2.0015192f-5 0.0001920022 -9.2148824f-5 -0.00027107893 7.134775f-5 -5.964702f-5 -6.571445f-5 2.7963786f-5 -9.262696f-5 -2.5103594f-5 0.00010223104 5.1622614f-5 -2.2026301f-5 1.254287f-5 0.00014579088 2.36304f-5 7.303235f-5 7.281683f-5 4.0352123f-5 -3.9984137f-5 1.2523445f-5 -0.00015963228 0.00012911043 4.8932754f-5 0.00013561141 -0.00013185403 0.00012759509 8.5364074f-5 2.0973628f-5; -1.7181472f-5 0.00019318338 2.0593694f-5 0.0001161752 -0.00011930812 0.00013603903 -0.00010987557 -6.4540174f-5 1.47638475f-5 -3.1142154f-5 -3.624574f-5 -6.0115795f-5 -7.74782f-5 7.568695f-5 -0.00030369422 -9.913992f-6 8.140064f-7 -3.695471f-5 -0.00013045505 -4.530694f-5 -5.1929233f-5 -2.6289688f-5 -8.129863f-5 -7.988875f-5 -0.00011339903 2.6109321f-5 -5.692291f-5 3.05815f-5 2.5934985f-5 4.8892067f-5 1.0419741f-5 1.336733f-5; -2.1557169f-6 0.00010185733 3.7880352f-5 4.6719797f-5 -2.973251f-5 -0.00013843946 -5.426408f-6 8.202717f-5 1.0169656f-5 0.00010772212 -1.235484f-5 5.1513243f-6 7.812719f-5 -4.496978f-5 0.00015264601 2.2617805f-5 0.00020410333 -2.846844f-5 0.00012139527 0.0001086654 0.00019218314 4.8405655f-5 -2.9905848f-5 -3.442985f-5 0.0002102744 7.591319f-5 -1.6843209f-5 5.5400662f-5 -3.7750815f-5 9.070578f-5 9.854213f-5 -8.151482f-5; 3.2419605f-6 9.605828f-5 0.00022020529 4.7928505f-5 -0.00017803696 7.1042527f-6 -4.307741f-5 -2.8723398f-5 9.695373f-5 -1.8868292f-5 -5.3267745f-6 2.740418f-5 -7.256197f-5 -3.7782458f-5 -0.00021100524 0.00032424802 0.00015173941 -0.00015614518 -4.1467785f-5 -0.00011328176 5.6386067f-5 9.703396f-5 -7.542526f-5 -4.3395998f-5 6.7283736f-6 -0.000127924 1.1097856f-6 -7.311346f-5 2.825613f-5 -4.6871122f-5 
-0.00012457886 -0.00019774986; -4.4950077f-5 -9.830651f-5 -6.931714f-5 -3.389492f-5 9.116958f-5 0.00013329582 -0.00017824746 -1.1541629f-5 -6.630197f-5 -4.640424f-5 -0.00014209193 4.698149f-5 7.6787015f-5 0.00017277346 1.3575845f-5 9.911704f-5 8.136254f-5 5.7232457f-5 8.8289766f-5 6.26845f-5 -1.1980716f-5 -8.48268f-5 3.9258906f-5 -8.3469524f-5 -0.0001209443 0.00016685402 9.892024f-5 -6.957656f-5 -9.466292f-5 6.427767f-5 5.694294f-5 -4.737066f-6; -3.1270545f-5 -8.123784f-5 2.3985206f-6 1.34730435f-5 -0.000108163935 0.00018529857 7.051571f-5 0.0001259554 -0.00018397134 -5.6940007f-5 0.0001730629 -1.01659325f-5 -0.00017976086 5.0634408f-5 0.00010284774 -9.760483f-5 4.6595393f-5 -0.00015115508 4.5906876f-5 5.4053955f-5 -7.19007f-5 5.9659018f-5 -0.00017484365 0.000100501755 -8.1142774f-5 3.779759f-5 0.00011467142 6.064038f-5 -0.00011474916 6.891283f-5 6.8103305f-5 1.0438638f-5; -7.2932185f-6 -4.908901f-5 -7.023493f-5 5.5237873f-5 0.000126791 6.6472974f-5 -8.562885f-5 -0.00012439773 -4.4160905f-5 1.1599765f-5 -5.892378f-5 -4.004236f-5 -2.0834117f-5 1.2902297f-5 0.00016958294 -4.0377425f-5 0.00014503967 -5.169101f-5 6.1955616f-5 -4.970561f-5 -1.6047354f-5 -0.0001427192 -7.312084f-5 1.4380274f-5 -1.0596746f-5 4.468191f-5 8.956387f-5 1.6505868f-5 -6.530687f-5 3.0997184f-5 -8.389718f-5 -0.00019730409; 1.6698988f-5 0.00010588058 1.2021213f-5 0.000101326266 -2.0888761f-5 0.0001045014 -2.2996626f-5 -6.107836f-5 -0.00024456676 3.670595f-5 0.00010476234 -2.1824667f-6 -2.3347462f-5 -0.00021627672 -0.00015973381 6.828117f-5 -9.111476f-5 -0.00019055382 1.3807509f-5 7.0542706f-6 4.65025f-5 -0.00024270917 2.11825f-5 3.6171543f-5 7.084577f-5 -6.640373f-5 -8.8160916f-5 2.574334f-5 8.661719f-6 9.9288794f-5 4.2471118f-5 3.6291076f-6; 9.246665f-5 4.9583887f-6 -0.00014873774 0.00013083834 -4.4991535f-5 -6.573202f-5 2.9811756f-5 3.3715987f-5 5.6227286f-6 -2.3887696f-5 1.3696698f-5 6.905904f-5 -4.474729f-6 -1.0248775f-5 -2.4751089f-5 -5.2762378f-5 -8.24235f-5 3.5932204f-5 1.0048753f-5 
-0.000106482425 3.0960473f-5 6.443919f-5 8.048831f-5 -8.316454f-5 -0.00017525021 0.00013779434 -0.00011660439 -0.00011954991 6.5317734f-5 7.2293624f-5 8.251316f-5 8.861964f-5; 0.00014635787 -4.699467f-5 2.2389679f-5 -6.5145265f-5 -8.975016f-5 6.256606f-5 0.00012472954 -5.6452725f-5 -0.00016134759 0.000100864025 8.361088f-5 -0.00023614768 -0.00019685789 8.0380785f-5 3.1244188f-5 -7.4716925f-5 -1.5209846f-5 -9.512541f-5 -0.00012058842 -0.00018053362 0.00019496764 -2.0349049f-5 -8.62151f-5 -9.971143f-5 -9.4534764f-5 -8.782304f-5 -8.1334445f-5 -2.4485569f-6 0.00010602754 -3.8188842f-5 7.26672f-5 -1.555162f-6], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[1.47009405f-5 -9.363873f-5 0.00019452645 -0.00010539466 1.9783949f-5 -3.7555194f-5 -1.8680816f-5 -9.407527f-5 -0.00015401204 -0.00010836691 -0.00015805512 5.297026f-5 -7.349804f-5 -8.106547f-5 9.3420895f-6 8.187151f-5 -6.369891f-5 -0.00012658587 -2.7731998f-5 0.00024427054 -4.6426107f-5 6.286345f-6 4.942947f-5 7.391273f-5 -6.764785f-5 -0.00010658569 5.9423586f-5 5.098941f-5 -6.149197f-5 -0.00012376954 -4.0581588f-5 -0.0001330456; 4.2996024f-5 -0.00012058678 3.4387427f-5 -3.4204364f-5 7.026914f-5 -1.1387028f-5 1.5478607f-5 8.25651f-5 6.291439f-5 -7.054833f-5 0.000121685334 -5.6104825f-5 -1.2276271f-5 0.00014375702 -5.4637905f-5 6.110635f-5 -0.00013161478 -1.3225972f-5 -3.5095254f-5 -3.662709f-5 7.87421f-5 0.00010789388 -0.00019727991 -4.26787f-5 -1.08027125f-5 0.00014412122 0.00011471914 -7.330481f-5 -3.490829f-5 8.8984314f-5 -0.00014078821 0.00012341775], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64
# Promote the Float32 parameters to Float64 and flatten them into a ComponentArray.
const params = ps |> f64 |> ComponentArray
# Wrap the network so that it can be called as `nn_model(x, parameters)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Displayed value of `nn_model`:
#
# StatefulLuxLayer{Val{true}()}(
#     Chain(
#         layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#         layer_2 = Dense(1 => 32, cos),      # 64 parameters
#         layer_3 = Dense(32 => 32, cos),     # 1_056 parameters
#         layer_4 = Dense(32 => 2),           # 66 parameters
#     ),
# )         # Total: 1_186 parameters,
#           #        plus 0 states.
#
# Now we define a system of ODEs which describes the motion of a point-like particle
# with Newtonian physics. It uses u[1] = χ and u[2] = ϕ, with
#     χ̇ = (1 + e cos(χ))² / (M p^(3/2)) * (1 + NN(χ)[1])
#     ϕ̇ = (1 + e cos(χ))² / (M p^(3/2)) * (1 + NN(χ)[2])
# where p, M, and e are constants.
"""
    ODE_model(u, nn_params, t)

Out-of-place ODE right-hand side for the state `u = [χ, ϕ]`, with a neural-network
correction.

Both rates share the factor `(1 + e cos(χ))^2 / (M p^(3/2))`, built from the global
`ode_model_params`, and are multiplied by `1 + nn_model([χ], nn_params)` element-wise.
`t` is unused. Returns the vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple hence we can safely
    # ignore it, however, in general, we should use `st` to store the state of the
    # neural network.
    correction = 1 .+ nn_model([first(u)], nn_params)
    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))
    base_rate = numer / denom
    return [base_rate * correction[1], base_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Integrate the neural ODE with the untrained parameters on the same grid as the
# reference model and convert the trajectory into a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=params, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural ODE with parameters `θ`, convert the trajectory into a waveform, and
return its mean-squared error against the global target `waveform`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0=u0, p=θ, saveat=tsteps, dt=dt, adaptive=false)
    pred = Array(trajectory)
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
# Warmup the loss function
loss(params)  # 0.0006850909326021244
Now let us define a callback function to store the loss over time.
# Loss recorded at every callback invocation.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses` and print a progress line that includes `θ.iter`.

Always returns `false`.
"""
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.
# Gradients of the objective are requested through the Zygote AD backend.
adtype = Optimization.AutoZygote()
# The objective takes (parameters, extra argument); the extra argument is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
# Optimization starts from the initial network parameters `params`.
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a backtracking line search, capped at 1000 iterations; `callback` is invoked
# by the solver to record and print the loss.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00012044786126345915; 0.0003070958482562322; 5.7402325182868226e-5; 0.00011184847244298293; -6.883471360187633e-5; 7.705064490447721e-5; 3.043840843019917e-5; -5.8436588005828755e-5; -1.7628188288633957e-5; -0.0001782047329470323; 6.849968485766238e-5; 0.00020898171351289927; 9.589010005568473e-5; -9.312581823904908e-5; 8.813666499894947e-5; -9.45639621931916e-5; -2.0345389202693778e-5; 5.0153455958922514e-5; 5.262237755225584e-5; -7.524622924378074e-5; -2.342924381079921e-5; -9.027775377017654e-5; -0.00016811605019016791; 2.9423068554005987e-5; 0.00012099288869635167; 2.160297117365517e-5; 7.310770888567701e-5; 0.0001693019585218477; -1.4844130419082516e-5; 8.359299681609998e-5; 6.175473390606925e-5; -6.855379615437227e-5;;], bias = [-7.826422390878514e-17, 5.504749459140146e-16, 5.893507539549282e-17, -4.340301802950623e-17, -1.0878276279873525e-16, -1.1537042678348924e-16, 1.1559873129106244e-18, -3.6859474373687494e-17, -2.26927109857747e-17, -2.527918449895972e-17, 8.487545803641119e-18, 1.471280313282973e-17, 1.7467230462522043e-16, -1.0513802227010279e-17, 2.6356795299941604e-16, -1.5331616722658953e-17, -9.972130092617212e-18, 1.4087988786074686e-17, 7.808038252087731e-17, -1.6915004213949867e-16, -2.817601127363989e-17, -1.9795539351462997e-16, -1.37183865988913e-16, 2.1180001984904967e-17, 1.560796566391367e-16, 5.806034979059533e-17, 1.383536935223537e-16, -2.2349130102770797e-17, 4.7128078813688984e-18, 5.651200236630362e-17, 6.671746361097022e-17, -8.425400634997288e-17]), layer_3 = (weight = [9.479709860770115e-5 1.1093069768849598e-5 -2.4355040115008174e-5 -7.803676981111871e-5 -9.96789529393961e-5 -7.106850516848038e-5 -4.939447881170233e-5 -0.00016805196337500312 0.000162029784586228 -0.0001236348614851592 2.584050543973474e-5 1.2229272395952676e-5 7.318628940865569e-5 -2.612446569528713e-5 2.361675997107955e-5 1.6183328806252148e-5 -7.112078292393732e-5 -5.5553162395902e-6 
-0.00011036091728001547 6.332874199742684e-5 0.00022915568861080657 2.7695645438686077e-5 7.957968064019308e-5 -9.581660198275093e-6 0.00015391706274258058 5.407038788562033e-5 -2.646150987984838e-5 -7.930163683468632e-5 -5.4252311930578645e-5 -9.07783831293301e-5 0.00010342204786738646 0.00010394598229922605; -0.00012220364994998027 0.000102249645281847 3.1924851940242644e-5 -0.0001111578731767369 3.449767420836414e-5 -2.6527843216786022e-5 -7.831633963785591e-5 1.7050634103181996e-7 5.022956786280949e-5 -0.00015993782842694057 0.00012007418438163275 7.446897682106367e-5 -1.3163609834622136e-5 1.2737225846036593e-5 -0.00015259519393718998 6.647069851471892e-5 3.199449376854569e-5 6.36443018477836e-5 3.413237384240689e-5 0.00020714868326042503 -2.771651277423841e-5 -5.5389058685863326e-5 -8.435296156373218e-5 -5.819283440057648e-5 -2.590811716497343e-5 5.054735622554543e-5 -6.426113220793481e-5 -8.987300505092695e-5 3.878844829461985e-5 6.887620286154336e-5 -5.548580345628197e-5 -0.00019479715963523214; 9.421877910194828e-5 0.00019582936072525691 2.165037692744876e-5 -6.0679453620238946e-5 0.00011328644466399192 -9.803240829800401e-5 -0.00010749321229244215 0.00010216742082544716 7.52553707365828e-5 -1.5017951205380822e-6 1.5065868117254986e-5 -3.96191397363348e-5 2.8349240144467794e-5 3.6040502143311714e-5 0.00010334905816985895 8.987895414744052e-5 0.00018331113385781136 -8.836894705612939e-5 -2.9818404821104435e-5 -1.1720058703339893e-5 0.00013973027232847584 0.00013391314421592841 1.1930697089551404e-5 -8.68364848634284e-5 9.078879081921548e-5 5.69605392024335e-5 0.00015078518770447847 1.2861513985112288e-5 -0.0001952711143735845 5.1443118697951826e-5 -9.321309046075606e-5 -3.6838636172013535e-5; -0.00012463821500355963 2.206010261790298e-5 1.5049379124634096e-5 5.953868216087514e-5 9.855590300352937e-5 4.1725624522172215e-5 5.422700390985968e-5 1.9697255391606224e-5 9.599073872182562e-6 9.985449049359593e-6 5.4232122546040175e-5 5.9768686094987024e-5 
-6.382018060472297e-5 -1.3153571364143404e-5 0.00010906870860557608 -5.671312707059512e-5 -6.025204731171934e-6 0.00017932902968672792 0.00011908065183737995 -0.00022918915661799126 -4.179766862003321e-5 6.562847678500754e-5 -8.739891805968294e-5 -1.3070241826923397e-6 7.177459645700799e-6 5.834541783619486e-5 -7.047831483001214e-5 -0.00010761542094920821 -0.00021593474835904496 -7.329497626565768e-5 0.00010252795740447658 0.00025426721671314604; -2.0927359636408396e-5 0.00010610610910687504 8.071268233936694e-5 -5.726396247046192e-6 2.025602474634419e-5 7.86494754184998e-5 6.850679050433044e-5 7.128120379206407e-6 -1.6476872823544817e-5 -0.00015642107906286583 -4.525170448585062e-5 -0.0001575711896829453 -1.2613153793498386e-5 -2.946659083943648e-5 -0.0001269795859651619 0.00010910717968829129 -9.883112512295516e-5 3.4887253369674884e-5 3.1619322490882804e-5 -7.816379662370005e-5 1.9139054294215365e-5 0.0001375141334766541 -1.5708520785563612e-5 -1.574713974957757e-5 -2.6795819629916174e-5 0.0002383706639567445 -5.7553620771501724e-5 -3.581435302192758e-6 -1.7904997174572171e-6 0.0001751260467618728 3.992184132687208e-5 0.00012229098205455004; 0.00016422892692075395 0.00012098091862436482 -9.411249742102109e-5 -8.573668422600828e-5 -5.6601323406950616e-5 -0.0001163776189363967 4.3492806768471815e-5 3.1007361727959926e-5 9.016456503149792e-6 -8.829584952169459e-5 6.572202870897665e-5 -0.00015980001010890921 3.532815277735252e-5 -1.8567399123143054e-5 -3.830433455137002e-5 0.0001086213003568866 0.00010249086223352578 -0.00011451159046687387 -8.428994008995805e-5 7.51260584061561e-5 1.4619647814970114e-5 -4.159591593038703e-5 0.00019096869688522768 -1.6313989575292448e-5 0.00011622830676655562 0.00013301460309462498 3.296070532682976e-5 -1.9020351131495686e-5 1.0311165547589261e-6 -4.63546796103883e-5 -2.7868124866033275e-6 6.05640332346847e-5; -0.0001792206402726117 -8.46319834792324e-5 -0.00013366313389015069 6.519263668706381e-5 6.543699403482255e-6 
7.954591827256092e-5 -1.771089948051849e-5 -1.1513716269474523e-5 -0.00021720358374041304 -9.627547039771272e-6 1.0176713733744323e-5 -9.72598437982366e-6 0.00020380312897286924 4.1878664053730326e-5 -2.6683805042193036e-5 -1.7136184321496878e-5 -0.00010807164100442806 5.809906914037007e-5 6.439503166147805e-5 9.732843029996735e-6 -3.4849916660474544e-7 1.4505662950615316e-5 -8.345350390044622e-5 -5.951043378190846e-5 0.0001367679124839959 -1.2716469352382955e-5 1.9112287266398084e-5 -3.2097266664268535e-5 2.615345700432745e-5 -1.6882782276292628e-6 -0.00010711526003178803 4.306133822210505e-5; 0.00023832365890921772 -8.090550050103263e-6 -7.212415026078278e-5 -4.859996066006159e-5 4.577033009336987e-5 6.363884508100997e-5 -1.8423148550883755e-6 6.443215729159957e-5 -2.1571171469253507e-5 0.00013791091818785074 0.0001069334922656762 4.11401871540936e-5 -1.25326722401896e-5 1.8340284848873413e-6 0.0001299272427284505 1.969714180455445e-6 1.3528456424688432e-5 8.118169548179677e-5 -3.5809318301466275e-5 -0.00011301767322510753 -2.047369330757113e-5 -0.00014840742901942025 2.7489390002360114e-6 -0.00011876260928619935 0.0002498536506616858 -0.0002008123973202547 -4.2963836545090125e-5 -9.851927994095019e-5 -1.1409145240753913e-5 -1.2845904034469643e-5 -8.405517096907177e-5 -0.0001434312269840975; -8.377320124679336e-5 -0.0001429517950480566 -3.179911151790222e-5 -0.00011413057006108846 -1.2702286646233732e-5 -1.1864000832135042e-5 -2.153530737530813e-5 0.00010494451920123622 4.498639749100936e-5 1.08818159179907e-5 1.6546954732914086e-5 -4.167402663992066e-5 -9.448632851387939e-5 4.819151866570434e-5 8.538323084678737e-5 -0.0001551997637460533 4.610952886025002e-5 0.0001596680004996597 6.234618843620437e-5 -8.106176471017249e-5 -8.154284375166927e-5 -9.128606404090029e-5 0.00013763287928045 -8.096660973649851e-5 -3.287259539035228e-5 -0.0001637821632056166 2.0194115798515488e-5 2.073205826777276e-5 3.6563447004589555e-5 4.80621812431531e-5 -8.464645440369958e-5 
-0.0003038552757669803; 4.9583548694723056e-5 -2.9255870380135993e-5 -9.853338873442037e-6 2.996594079035578e-5 -7.693626290689753e-5 8.571544852769521e-5 -3.8042977088270515e-5 -2.5096038065172852e-5 6.097796389097305e-5 3.417869478313553e-5 -7.397785126501297e-5 3.223916822778619e-5 -2.808821196937089e-5 6.516032257080871e-5 -7.802900260118066e-5 4.9558330225626746e-5 -3.168401205018194e-5 7.235044779302874e-7 -0.00020183015627823567 -0.00010343487125625342 -9.245185689339433e-5 -0.00028383066425136485 0.00013634378077427808 -9.012105480339365e-5 0.00011031901188857376 -0.0001674799383761364 1.7705853985021767e-5 0.00012205150081741616 -0.00018186366256829403 -0.00010231379170705871 -0.00010123122380152037 5.099350195108378e-5; -2.8645497660470624e-5 0.00017739587587023385 4.087321962187132e-5 0.0002170097397892129 8.691088099921101e-5 8.231495688859282e-5 0.00019607122179776635 -3.0226976160620687e-5 -4.2343937603717096e-5 4.1488601926967965e-5 -3.12487752681693e-5 0.00015488201119193332 -8.256301567700047e-5 -8.555203727256762e-5 8.56010523319742e-6 7.92488502964805e-6 -0.00029826145670186493 -0.00014797749805539734 1.927210017937798e-5 5.274379531227909e-5 -5.509471862852344e-5 3.1401778177393224e-5 -0.00018584005352852427 -4.7431894710015264e-5 -1.528259577855146e-5 3.406699055530366e-5 0.00013515569684285613 -1.551781657428995e-5 0.0001025735439169618 -1.9498232869512084e-5 7.721170208141242e-5 0.00014617752116363473; 8.923925946009664e-5 -3.3016849632146705e-7 6.691577748870842e-5 -7.456251893882146e-5 -0.00016548420223397897 0.0001356356837554848 1.6629162410489398e-5 -6.273888408895228e-5 2.8986363241082703e-5 7.61728910497178e-5 -9.023570289127286e-5 -1.2296578069737981e-5 0.00010544887066362278 0.00018477811266211688 9.56760517788117e-5 5.843695856174326e-5 -1.5401356512930618e-5 -1.9986370325107464e-5 6.0527624946650485e-5 0.00012370877396324805 6.986813553386041e-5 8.721051875045465e-5 6.30430762889614e-5 -0.00013386078041502627 0.00013527648427998028 
-0.00010998773775053531 9.594841997614313e-5 -5.517125879988989e-5 4.791844860893375e-5 -7.726207383666146e-5 3.809618959212994e-5 -1.954172200339139e-5; 4.91898701256827e-5 -6.0185191458012226e-5 -1.9201952298337515e-5 -6.947998720354087e-5 -5.105011389607352e-5 -0.00010519093287077505 4.398581651892048e-6 0.0001357776438198724 1.1363668065293843e-5 -3.446913417794179e-5 -4.293398326628631e-5 -0.00013568751985805326 4.2521593215477556e-5 -6.794601929888288e-6 -0.00014049267868178643 0.00021724219020323164 0.00011103256308340624 -9.377766203035417e-5 5.945953141415679e-5 1.3897257460465105e-5 -1.1907366413672737e-5 2.5910618362084733e-5 3.336907518312852e-5 7.304275838758028e-5 5.1756518743768024e-5 6.0254175124989525e-5 0.00013129597407167068 -2.5741226891282866e-6 -0.00019444079028109866 8.033905230329613e-5 -4.7242987264346195e-5 -0.0001980737632293512; 5.662736540454023e-6 0.0001118356737345006 8.290838806045322e-6 0.0001748900755694755 3.971778976052797e-5 -4.416936310802809e-5 6.65162730649752e-5 -0.0001769473822866864 5.9012597114486875e-6 0.00012934749945203115 -7.420358481680245e-5 -4.200206633754004e-5 -8.692983858849199e-5 1.693386882039915e-5 -9.504087240997023e-5 2.7225636779460463e-6 -5.2556817835226085e-5 -0.0001174649372194258 8.18801342313813e-5 3.7349425539344225e-5 -0.00023039300170124642 -2.2167236537586708e-5 5.6220753023134123e-5 -4.08440904072779e-5 2.438163320458636e-6 7.513872667119709e-5 4.7690194761583304e-5 -4.037310035293005e-6 7.843987229649785e-5 0.0002011966204656916 -7.944375859436767e-5 7.177431659444033e-5; 2.6963480798504755e-5 0.00010420024833657849 -0.0001555867123322252 2.461020327543546e-5 8.002066955011967e-6 0.00015259313719368666 0.0001462779552305441 0.0001520500015097047 0.00010738275947303995 0.00014117844026555646 8.014065739487503e-5 -5.454475043438797e-5 8.97494980229336e-5 2.9886072901058455e-5 -1.4071361271566627e-5 6.120133244919523e-5 -0.00018424732055493154 -3.519551622659556e-5 0.00019735749143663192 
4.4195644171192814e-5 -2.5408638372681994e-5 -9.413540768648515e-5 3.8231607198450754e-5 2.2941153168431317e-5 5.914778888173525e-6 0.00012099434449126487 -2.7503797661413667e-5 -0.00012624663153177525 -0.0001451772035004243 -3.1078164590075025e-5 -0.00011942443173822685 6.679928140581771e-6; 4.458530292538831e-5 2.0438897076417852e-5 -0.0002917754109199708 8.025910119102108e-5 2.9207426445621047e-5 0.00014586290398653859 9.301550653105533e-5 -3.512122294125983e-5 -6.501025841890315e-6 -1.644550048295749e-5 -8.102036463754071e-5 -1.032226828610768e-5 -0.00010175365696852545 4.5309649971728026e-5 -6.831757582915559e-5 0.00023407086609438522 -6.817902436187601e-6 0.00011292641920078583 8.459568106843212e-5 -4.658880703233748e-6 6.071262304828407e-5 -8.575646734755074e-5 -0.00014532252044374228 -3.125365492963306e-5 -5.221817719359037e-5 8.569307278202768e-6 -0.00010075995760928615 -0.00021139649562029248 -0.00010003729494713364 -3.635792473289781e-5 -5.963979220325937e-5 -8.283616717165154e-5; -6.392519025516356e-5 -2.8297488309447895e-6 -0.00014325734991266513 7.168570119270384e-5 -0.00017633120477817082 -0.0001760564355148283 3.6592284238075525e-5 -4.598609120271284e-5 9.378415544672053e-7 3.271585411008565e-5 3.2461763118272024e-5 0.00010563569049433479 -4.372811781044927e-5 -7.591941720724938e-5 -2.696683622466337e-5 5.598741518647522e-5 -6.650716326037058e-7 -0.00010557535460370746 -1.3497619869077132e-5 -2.4700302668269954e-5 -2.3228638377396858e-5 -7.982222080816653e-6 -6.611846219433499e-5 0.0001197057190498967 -6.244214995008695e-6 -2.0564319123288184e-5 0.00010964684589148461 -4.086887019488038e-5 7.986686796962036e-5 -6.450870750390893e-5 -2.049614158145251e-5 0.00014891556748502824; -3.480119766454562e-5 -4.9268817285870596e-5 -3.212311775561354e-5 7.360206155435504e-5 -0.00012971794557662428 0.00017413054093348273 -3.909924082761892e-6 6.959486427811992e-5 -0.0001776966578224519 -0.00017443344904011105 4.329880403682838e-5 -7.001569382978925e-5 
-7.603680341408689e-5 0.00013797319343874994 -0.0001390076554277661 0.00011136847252354822 5.950253262861331e-5 -7.913240323679552e-5 0.00010529282236101107 -0.00013432885088396818 6.157370087095256e-6 -5.107308742727389e-5 -8.332270726230436e-5 -0.00028278197486790137 1.3015145220653087e-5 -0.00015658569963554215 -0.00027612022449167116 2.9990075297356413e-5 -6.69871565084751e-5 7.945444895748592e-5 1.9139808962141595e-5 -3.347535447820772e-5; -0.00011269484638372237 -1.0147281862919741e-7 -9.645613134974853e-5 -0.00010852426747927321 0.0001521742397114057 -1.1770098344037731e-5 0.00011706481201205071 -5.38126532677212e-5 -1.57246195061261e-5 -5.2060813676985544e-5 2.0903486084228403e-5 0.00014612653650208788 -2.2756323223197912e-5 -1.5433182842876117e-5 -8.163149867808508e-5 6.514984307674674e-5 -2.5759600238724405e-5 8.14864779482941e-5 -9.010152824824934e-5 -4.023429027196167e-5 -2.9014745440998788e-5 9.756707195869622e-6 0.0001334344287994764 -6.795028274549505e-5 -2.9369704852190614e-5 -7.148045368861757e-5 -0.00012649950621320094 -0.0001314772507515288 5.907573732348291e-6 6.739632409395556e-5 -2.9029111819306716e-5 7.472229664077799e-5; -0.0001172450080147973 -0.0001050345032225728 9.4280816383897e-6 7.029624101873053e-5 -0.00016247901076982808 -2.1391567871402878e-5 6.04694017866031e-5 0.00017011529997856172 -0.00018952412357154717 0.0002608418639184076 -6.031595477085645e-5 9.747588627675575e-5 -4.709260029900569e-5 -7.331375819464003e-5 -3.544914362552433e-5 2.556740254508899e-6 1.9232922922154575e-5 6.033468743137767e-5 -0.00011240855529737886 -3.3134410487916456e-5 -6.075413840423732e-5 5.541366944303131e-6 -0.00017415527837415416 0.00012354768695243263 -0.0001332861807995503 6.886448535536015e-5 -8.898691349684164e-6 4.395413851766685e-5 -0.000189206455262112 -0.00018594016147377814 0.0001676305459014032 -6.958339655372401e-5; -0.00015020869373967796 -4.694088544046125e-5 6.091992545579744e-6 6.197852726638079e-6 -8.820703565999029e-5 
7.052000128339733e-6 0.00017357112905319055 5.13451838198212e-5 -0.000154466263993352 -0.00016845692640330044 0.00010486255145241416 -4.641834071648639e-5 3.088421603041259e-5 0.00010388923205044652 0.00020760139399516065 0.00020272561472682793 0.00020123287926270868 0.00012066131115733724 5.669533733470424e-5 6.34988706589753e-5 -5.070439656090462e-6 -4.4785262943891675e-5 0.00010177417665560668 7.317733561823591e-6 -1.3718370017341567e-6 5.001673947862589e-5 -2.8831652671044243e-5 -9.939429342906084e-5 6.2953655653402856e-6 -6.768079525483108e-5 0.00017016218283987582 -9.373605767888107e-5; -8.646897430019799e-5 0.00018930811103477386 -9.488807944301551e-5 -4.288568987697653e-5 5.763999553443788e-5 8.001314319250995e-5 -3.628422080958233e-5 -9.276320878136941e-5 -4.4992836305871145e-5 6.2116543008472e-5 0.00020514951512729676 0.00015640851535414122 -4.665845216089127e-5 9.789165534591118e-5 8.454460216917363e-6 0.00010214958939747692 6.784165011661094e-6 5.5481226186267876e-5 -0.00014552137389266885 -3.263778733974423e-5 2.195940516237209e-5 -9.599897992750799e-5 4.535673116612453e-5 -9.362418739777027e-5 -7.397086105749482e-6 1.945827472361312e-5 5.533652194123623e-5 -0.0001118192037073969 1.252560406562455e-5 0.0001523851145203212 -6.473707133857582e-5 -9.336158353556022e-5; 7.615622055258601e-5 8.753988552140787e-5 0.0001590346237435845 -2.0012570507748252e-5 0.0001920048253169179 -9.214620329198638e-5 -0.00027107630555742206 7.135037051289345e-5 -5.964439876601768e-5 -6.571182709646938e-5 2.796640731053321e-5 -9.262433557067666e-5 -2.510097326644716e-5 0.00010223366009545071 5.1625234868923275e-5 -2.2023679753104213e-5 1.2545490813004076e-5 0.00014579350138324386 2.363302157850993e-5 7.303497297930346e-5 7.281945183881802e-5 4.0354743782743566e-5 -3.9981516215421716e-5 1.252606582516323e-5 -0.00015962966242564857 0.0001291130484105951 4.893537518460142e-5 0.00013561402922735284 -0.00013185140958280032 0.00012759771382211407 8.53666952253223e-5 
2.0976249311434904e-5; -1.718416350959445e-5 0.00019318069135594956 2.059100225559281e-5 0.00011617250894013975 -0.00011931081061667325 0.00013603633888814207 -0.00010987825916564918 -6.454286594126981e-5 1.4761155769140977e-5 -3.1144845336154177e-5 -3.6248431199430826e-5 -6.011848708328693e-5 -7.7480893735272e-5 7.568425681492161e-5 -0.0003036969133903855 -9.91668329553109e-6 8.113146646641217e-7 -3.695740050935874e-5 -0.00013045774502862915 -4.530963232574468e-5 -5.1931925183831925e-5 -2.629238005502813e-5 -8.130132445548043e-5 -7.989144032073109e-5 -0.00011340172167583933 2.6106629236513477e-5 -5.6925600431361e-5 3.057880931702166e-5 2.5932293654098618e-5 4.8889375109492606e-5 1.0417048913574617e-5 1.3364638478003587e-5; -2.150803207153488e-6 0.00010186224482325686 3.7885266114878086e-5 4.6724710605012965e-5 -2.972759592494731e-5 -0.00013843454270073206 -5.421494174969942e-6 8.20320807312749e-5 1.017456962899976e-5 0.00010772703045826938 -1.2349926132906986e-5 5.156237980032137e-6 7.813210196644483e-5 -4.4964866342809644e-5 0.00015265092462385716 2.2622719020261328e-5 0.00020410824310120151 -2.8463525618737633e-5 0.00012140018462987196 0.0001086703147072443 0.00019218805629379985 4.841056817412028e-5 -2.990093470120641e-5 -3.44249359707639e-5 0.00021027931744781802 7.591810082403891e-5 -1.6838295219083256e-5 5.540557559893564e-5 -3.774590130460461e-5 9.07106901783573e-5 9.85470420479902e-5 -8.150990960103256e-5; 3.2404618025527024e-6 9.605677785776946e-5 0.0002202037869040074 4.7927005817900486e-5 -0.00017803846259879227 7.102754062785649e-6 -4.3078907600590044e-5 -2.872489660033835e-5 9.695222996130087e-5 -1.886979027287454e-5 -5.328273181234749e-6 2.7402680808596794e-5 -7.256346709139708e-5 -3.77839567761225e-5 -0.0002110067433193153 0.0003242465174310088 0.00015173791342136886 -0.00015614667774456015 -4.146928387739141e-5 -0.00011328326201576975 5.638456803468502e-5 9.703246194591634e-5 -7.542676013938208e-5 -4.3397496318663513e-5 6.726874909209749e-6 
-0.00012792549370600404 1.108286875265301e-6 -7.311495557471215e-5 2.8254631780567217e-5 -4.687262100445277e-5 -0.00012458035764096632 -0.0001977513600820517; -4.494939982486728e-5 -9.830583418898948e-5 -6.931646182146446e-5 -3.389424434225801e-5 9.117025932119604e-5 0.0001332964968523884 -0.00017824678049600223 -1.1540951863018005e-5 -6.630129043423097e-5 -4.640356179970536e-5 -0.0001420912519906703 4.698216819724519e-5 7.678769200629834e-5 0.00017277414029286532 1.3576521479087714e-5 9.911771506359153e-5 8.136321433176757e-5 5.723313391264022e-5 8.82904425734603e-5 6.26851762999723e-5 -1.1980039534128714e-5 -8.482612590250258e-5 3.925958320087612e-5 -8.346884694127835e-5 -0.00012094362297738742 0.00016685469766435684 9.89209149620439e-5 -6.957588609020376e-5 -9.466224561855652e-5 6.427834908300406e-5 5.694361628321395e-5 -4.736388990742579e-6; -3.1270368770278546e-5 -8.12376657156799e-5 2.39869694472504e-6 1.347321988963392e-5 -0.0001081637584027377 0.00018529874972249288 7.05158835737494e-5 0.0001259555737599257 -0.00018397116042957702 -5.693983081779682e-5 0.00017306307739294046 -1.0165756127904917e-5 -0.00017976068017357144 5.063458405150925e-5 0.00010284791818070123 -9.760465588239131e-5 4.659556901263943e-5 -0.00015115490356744338 4.5907052419566734e-5 5.405413133506573e-5 -7.190052389684194e-5 5.9659194386982245e-5 -0.0001748434716025855 0.00010050193116716046 -8.114259805348844e-5 3.779776482791579e-5 0.00011467159668630831 6.06405573600881e-5 -0.000114748980256776 6.891300466325966e-5 6.810348161910581e-5 1.043881450844843e-5; -7.294586915030194e-6 -4.909037728144479e-5 -7.023629644473097e-5 5.5236504856202976e-5 0.00012678962900017436 6.647160568454407e-5 -8.563022186800025e-5 -0.00012439909972536387 -4.416227296397238e-5 1.1598396092367762e-5 -5.892514730740291e-5 -4.0043728070380985e-5 -2.0835485550288425e-5 1.2900928929501597e-5 0.00016958157299732038 -4.037879319448718e-5 0.00014503829822123837 -5.1692379777618254e-5 6.195424736424311e-5 
-4.970697830949819e-5 -1.6048722166401772e-5 -0.00014272057217831524 -7.312221044493675e-5 1.4378905113603626e-5 -1.0598114052207434e-5 4.4680541911561495e-5 8.956250473338258e-5 1.650449973304473e-5 -6.53082394209426e-5 3.099581532942282e-5 -8.389855123178093e-5 -0.00019730546103610612; 1.669723353815131e-5 0.00010587882427620494 1.2019457632851928e-5 0.000101324511367136 -2.0890516291835776e-5 0.00010449964468251288 -2.299838122155034e-5 -6.108011582284076e-5 -0.000244568518222077 3.670419612435549e-5 0.00010476058235042962 -2.1842216238772546e-6 -2.3349216983765092e-5 -0.0002162784745206545 -0.00015973556726122794 6.827941836207654e-5 -9.111651304294367e-5 -0.0001905555778404221 1.3805754339847552e-5 7.052515617836602e-6 4.650074554490853e-5 -0.00024271092258742582 2.1180745922626974e-5 3.616978795151239e-5 7.084401152213909e-5 -6.640548378811227e-5 -8.816267070243795e-5 2.5741585365629093e-5 8.659964284451916e-6 9.928703950834824e-5 4.246936298557741e-5 3.62735265924364e-6; 9.246661526079399e-5 4.958351073886451e-6 -0.0001487377824920872 0.0001308382985920004 -4.4991572703718186e-5 -6.573205731880754e-5 2.9811718068533555e-5 3.371594963057992e-5 5.62269102404799e-6 -2.3887734041333776e-5 1.3696660662692528e-5 6.905900105296779e-5 -4.474766537593832e-6 -1.0248812340857854e-5 -2.4751126456674126e-5 -5.2762415488966665e-5 -8.242354064320124e-5 3.59321663020933e-5 1.0048715612658765e-5 -0.00010648246219686137 3.096043534272441e-5 6.443915359371442e-5 8.048827172418251e-5 -8.316457509833297e-5 -0.0001752502515406989 0.00013779430324605758 -0.00011660442518374785 -0.00011954995110469744 6.531769637179652e-5 7.229358621801348e-5 8.251311979650614e-5 8.861960546740427e-5; 0.00014635500728291592 -4.699753090709287e-5 2.2386819446906613e-5 -6.514812385757506e-5 -8.975301871869561e-5 6.256320220090713e-5 0.00012472668485387952 -5.645558458964313e-5 -0.0001613504498065662 0.00010086116552959086 8.360802193310261e-5 -0.00023615053915742714 -0.00019686074639066486 
8.037792603390812e-5 3.124132880730059e-5 -7.471978434223137e-5 -1.521270561098142e-5 -9.512827064937039e-5 -0.00012059127739958498 -0.0001805364806470624 0.00019496477788451816 -2.03519081962434e-5 -8.621795833203447e-5 -9.97142921816765e-5 -9.453762296215482e-5 -8.782590130904937e-5 -8.133730417212264e-5 -2.4514160569907907e-6 0.00010602467981770364 -3.819170134120258e-5 7.266434139709985e-5 -1.5580212304510985e-6], bias = [5.24049572234236e-10, -1.1769558817903184e-9, 3.1752074090508935e-9, 8.743321319526684e-10, 1.4303997215932788e-9, 1.1484794866719943e-9, -1.1877247765332698e-9, -5.6256971351011473e-11, -2.729272530315554e-9, -3.066510335341068e-9, 1.8155686955997358e-9, 2.6952243475144214e-9, -3.9820226689655825e-10, 5.162940068943036e-10, 1.972387764063652e-9, -2.0600106207800806e-9, -1.2678295678863777e-9, -3.9485646261008145e-9, -1.2059802785065523e-9, -2.3224956492643143e-9, 2.388915338509676e-9, 8.488451136162375e-10, 2.62118651821747e-9, -2.6917336226758314e-9, 4.91364944748469e-9, -1.4986824109263079e-9, 6.768741599787449e-10, 1.7638034128399654e-10, -1.3684556556205602e-9, -1.7549365296779343e-9, -3.7598208557155506e-11, -2.8591868683133405e-9]), layer_4 = (weight = [-0.0006527789721448691 -0.0007611186125829486 -0.0004729532222017344 -0.0007728745608167576 -0.0006476959173953861 -0.0007050350785554131 -0.0006861606983450748 -0.000761555192191664 -0.0008214917553951547 -0.0007758465791963222 -0.000825534945128708 -0.0006145094763365598 -0.0007409779577241811 -0.0007485453798728068 -0.0006581377300567333 -0.0005856082992428374 -0.0007311787910604668 -0.0007940653683220581 -0.0006952118794450025 -0.000423209247376977 -0.0007139058736092275 -0.0006611935557337788 -0.0006180502793227267 -0.0005935670090812137 -0.0007351271279522762 -0.0007740655500161425 -0.0006080563217314802 -0.0006164905100711819 -0.0007289718400736683 -0.0007912493744156664 -0.000708061507850692 -0.0008005252921169238; 0.0002696136627627799 0.00010603084792296516 
0.0002610049855982681 0.0001924132704576473 0.0002968867629828122 0.00021523060144378796 0.0002420962356082139 0.00030918274191383806 0.00028953196196782516 0.00015606922342341616 0.00034830294448360974 0.00017051275370702355 0.00021434136879020513 0.00037037465623458466 0.00017197970245732294 0.00028772395644660153 9.500284449988847e-5 0.00021339152572543755 0.00019152237446928098 0.00018999050762451416 0.00030535969115424885 0.00033451151540941215 2.9337676224411167e-5 0.00018393887873706938 0.00021581471121226138 0.0003707388430571308 0.0003413367807946356 0.00015331283391890576 0.0001917093335592534 0.00031560192680133044 8.582943324688896e-5 0.00035003531165538087], bias = [-0.0006674799197779137, 0.00022661764148690816]))Visualizing the Results
Let us now plot the loss over time
# Plot the recorded loss history against its index: a line with circle markers on top.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; marker=:circle, strokewidth=2, markersize=12)
    # The figure is the value of the block.
    fig
end
Finally let us visualize the results
# Re-solve the neural ODE with the optimized parameters `res.u`, using the same fixed-step
# RK4 settings as before. `prob_nn` and `soln_nn` are overwritten; `waveform_nn` is not.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
# Only the first element of what `compute_waveform` returns is kept.
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
# Compare three waveforms on one axis: the reference data, the prediction made with the
# initial parameters (`waveform_nn`), and the prediction made with the optimized
# parameters (`waveform_nn_trained`). Each series is a line plus circle markers sharing
# one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Prediction before training.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Prediction after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        markersize=12,
        alpha=0.5,
        strokewidth=2,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    # The figure is the value of the block.
    fig
end
Appendix
# Print the Julia version information, then GPU backend information for each backend
# whose module is loaded and which MLDataDevices reports as functional.
using InteractiveUtils
InteractiveUtils.versioninfo()

# The conditions short-circuit, so `MLDataDevices` and the backend module are only
# referenced once both are known to be defined.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.