Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model, which are not especially relevant from a scientific machine learning perspective.
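We need a very crude 2-body path. Assume the 1-body motion is the relative position vector of a Newtonian 2-body system; `one2two` splits it into the positions of the two bodies about their center of mass.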
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of two bodies of
masses `m₁` and `m₂` about their common center of mass. With `M = m₁ + m₂`, returns
`(r₁, r₂)` where `r₁ = (m₂ / M) .* path` and `r₂ = (-m₁ / M) .* path`, so that
`r₁ - r₂ ≈ path`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    body1 = (m₂ / total) .* path
    body2 = (-m₁ / total) .* path
    return body1, body2
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution to Cartesian orbit coordinates.

Each column of `soln` is one time sample, and `soln` must have either:
- 2 rows `(χ, ϕ)`. Then `model_params` must be the triple `(p, M, e)`; `M` is not
  used here.
- 4 rows `(χ, ϕ, p, e)`. Then `p` and `e` vary per sample and `model_params` is
  ignored.

The radius is `r = p / (1 + e cos χ)`. The result is a `2 × N` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.
"""
function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"
    # Views avoid copying rows out of `soln`.
    χ = @view soln[1, :]
    ϕ = @view soln[2, :]
    if nrows == 2
        # Test for `nothing` before calling `length`, so that a missing `model_params`
        # produces this assertion instead of a MethodError from `length(nothing)`.
        has_params = model_params !== nothing && length(model_params) == 3
        @assert has_params "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = @view soln[3, :]
        e = @view soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of `v`, sampled with uniform spacing `dt`.

- Interior samples use the central difference `(v[i+1] - v[i-1]) / 2dt`.
- The first sample uses the one-sided stencil `(-3v₁ + 4v₂ - v₃) / 2dt`.
- The last sample uses `(3vₙ - 4vₙ₋₁ + vₙ₋₂) / 2dt`.

The result has the same length as `v`, and `v` needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    # forward one-sided stencil at the left edge
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # centred differences for all interior samples
    middle = @views (v[3:end] .- v[1:(end - 2)]) ./ 2
    # backward one-sided stencil at the right edge
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left, middle, right) / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of `v`, sampled with uniform spacing `dt`.

- Interior samples use `(v[i-1] - 2v[i] + v[i+1]) / dt²`.
- The first sample uses the one-sided stencil `(2v₁ - 5v₂ + 4v₃ - v₄) / dt²`.
- The last sample uses `(2vₙ - 5vₙ₋₁ + 4vₙ₋₂ - vₙ₋₃) / dt²`.

The result has the same length as `v`, and `v` needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    left = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    middle = [v[i - 1] - 2 * v[i] + v[i + 1] for i in 2:(lastindex(v) - 1)]
    right = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left, middle, right) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one entry of the trace-free mass quadrupole of a planar orbit, sample by sample.

`orbit` holds `x` in row 1 and `y` in row 2, with `r² = x² + y²`. `component` selects
the entry:
- `(1, 1)` gives `mass * (x² - r²/3)`;
- `(2, 2)` gives `mass * (y² - r²/3)`;
- any other pair falls through to the off-diagonal entry `mass * x * y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    r² = xs .^ 2 .+ ys .^ 2  # trace of the (z = 0) moment tensor
    if component[1] == 1 && component[2] == 1
        return mass .* (xs .^ 2 .- r² ./ 3)
    elseif component[1] == 2 && component[2] == 2
        return mass .* (ys .^ 2 .- r² ./ 3)
    end
    return mass .* (xs .* ys)
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (finite differences with spacing `dt`) of the
`component` entry of the trace-free quadrupole of `orbit`, as built by `orbit2tensor`.
"""
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole contributions `(h11, h12, h22)` of a body of mass `mass`
following `orbit`, sampled with spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    components = ((1, 1), (2, 2), (1, 2))
    h11, h22, h12 = map(c -> h_22_quadrupole_components(dt, orbit, c, mass), components)
    return h11, h12, h22
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain polarizations `(h₊, hₓ)` for a single unit-mass body on
`orbit`. With `s = √(π/5)` in the precision of `dt`, the results are
`h₊ = s * (h11 - h22)` and `hₓ = -s * 2 * h12`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    q11, q12, q22 = h_22_quadrupole(dt, orbit)
    amp = √(T(π) / 5)
    plus = amp * (q11 - q22)
    cross = -amp * (T(2) * q12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the summed quadrupole contributions `(h11, h12, h22)` of two bodies, each
computed with `h_22_quadrupole` for its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    # element-wise over the (h11, h12, h22) tuples
    return map(+, first_body, second_body)
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain polarizations `(h₊, hₓ)` radiated by two bodies of masses
`mass1` and `mass2` following `orbit1` and `orbit2`, sampled with spacing `dt`.

With `s = √(π/5)` in the precision of `dt`, the results are `h₊ = s * (h11 - h22)` and
`hₓ = -s * 2 * h12`, using the summed two-body quadrupole.

Throws an `AssertionError` if the masses do not sum to 1 within a precision-dependent
tolerance (1e-12 for Float64).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2
    total_mass = mass1 + mass2
    # Masses are typically derived as m₂ = 1/(1+q), m₁ = q*m₂, so they sum to 1 only up
    # to rounding. A fixed 1e-12 bound rejects Float32 masses; scale the tolerance with
    # the working precision instead (for Float64 this is still exactly 1e-12).
    tol = max(1.0e-12, 8 * eps(float(typeof(total_mass))))
    @assert abs(total_mass - 1) < tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into the (2,2)-mode strain `(h₊, hₓ)`, sampled with spacing `dt`.

`soln` and `model_params` are converted to an orbit by `soln2orbit`. `mass_ratio` must
lie in `[0, 1]`:
- `0` is the test-particle limit; a single unit-mass body follows the orbit.
- Otherwise the orbit is split with `one2two` using `m₂ = 1 / (1 + mass_ratio)` and
  `m₁ = mass_ratio * m₂`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # test-particle limit: no need to split the orbit
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
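`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^{2}\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2}\sqrt{(p - 2)^2 - 4e^2}},$$

where $u[1] = \chi$, $u[2] = \phi$, and $p$, $M$, and $e$ are constants.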
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side `[χ̇, ϕ̇]` of the orbit ODE in a Schwarzschild background, for the state
`u = (χ, ϕ)` and constant parameters `(p, M, e)`. `t` is unused.

`p` and `e` are the parameters of the orbit shape `r = p / (1 + e cos χ)` used by
`soln2orbit`, and `M` divides both rates.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u  # ϕ itself does not enter the right-hand side
    ecosχ = e * cos(χ)
    den = sqrt((p - 2)^2 - 4 * e^2)
    shared = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    χ̇ = shared * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * den)
    ϕ̇ = shared / (M * (p^(3 / 2)) * den)
    return [χ̇, ϕ̇]
end
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions (χ₀, ϕ₀)
datasize = 250
tspan = (0.0f0, 6.0f4) # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1] # spacing of the saved samples (used for finite differences)
dt = 100.0 # fixed step size of the RK4 integrator
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic orbit with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep the first polarization returned by `compute_waveform` (h₊); it is the target the
# neural ODE is trained against below.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (the current value of χ = u[1]) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from `truncated_normal(; std=1.0e-4)` and biases start at zero (`zeros32`).
# Network: cos applied to the raw input, then two 32-unit cos layers and a linear
# 2-output head. Every Dense layer shares the same initializers.
const nn = let
    init = (; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init...),
        Dense(32 => 32, cos; init...),
        Dense(32 => 2; init...),
    )
end
# Seed the default RNG so that the initial weights, and hence the whole run, are
# reproducible.
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, nn)
# (The REPL display of `(ps, st)` is omitted. It lists 32×1, 32×32 and 2×32 weight
# matrices drawn from `truncated_normal(; std=1.0e-4)`, all-zero biases, and empty
# layer states.)
# Similar to most deep learning frameworks, Lux defaults to using Float32. However, in
# this case we need Float64
# `f64` converts the Float32 parameters to Float64. `ComponentArray` flattens the nested
# NamedTuple into a single vector that the optimizer can update.
const params = ComponentArray(f64(ps))
# Parameters are supplied at call time (`nothing` here), e.g. `nn_model(x, params)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# REPL display of `nn_model`:
# StatefulLuxLayer{Val{true}()}(
#     Chain(
#         layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#         layer_2 = Dense(1 => 32, cos),      # 64 parameters
#         layer_3 = Dense(32 => 32, cos),     # 1_056 parameters
#         layer_4 = Dense(32 => 2),           # 66 parameters
#     ),
# )         # Total: 1_186 parameters,
#           #        plus 0 states.
# Now we define a system of ODEs that describes the motion of a point-like particle with
# Newtonian physics, using
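$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathrm{NN}_1(\chi)\right), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathrm{NN}_2(\chi)\right),$$

where $u[1] = \chi$, $u[2] = \phi$, $p$, $M$, and $e$ are constants, and $\mathrm{NN}(\chi)$ is the two-component output of the neural network.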
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = (χ, ϕ)`.

The Newtonian rate `(1 + e cos χ)^2 / (M p^(3/2))`, with `(p, M, e)` taken from
`ode_model_params`, is multiplied component-wise by `1 .+ nn_model([χ], nn_params)` to
give `[χ̇, ϕ̇]`. `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple hence we can safely ignore
    # it, however, in general, we should use `st` to store the state of the neural network.
    # The network provides a multiplicative correction around 1 for each rate.
    y = 1 .+ nn_model([χ], nn_params)
    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))
    rate = numer / denom  # Newtonian rate shared by χ̇ and ϕ̇
    χ̇ = rate * y[1]
    ϕ̇ = rate * y[2]
    return [χ̇, ϕ̇]
end
# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
# Simulate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings as the reference solution.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the h₊ waveform produced by the neural ODE with parameters
`θ` and the reference `waveform` from the relativistic model.
"""
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
# Warmup the loss function
# The first call triggers compilation; the original run displayed 0.0007072778856636132.
loss(params)
# Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(state, l)

Optimization.jl callback. Appends the current loss `l` to `losses`, prints the iteration
number (`state.iter`) and the loss, and returns `false` so that optimization continues.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This gives good results here, apparently because the Newtonian model already provides a very good initial guess.
# Gradients of `loss` are computed with Zygote.
adtype = Optimization.AutoZygote()
# Optimization.jl objectives have the form `(u, p)`; only the trainable `u` is used here.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS starting with a small step norm and a backtracking line search,
# for at most 1000 iterations; `callback` logs every iteration.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback,
    maxiters=1000,
)
# retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-9.199089981833302e-5; -4.4559576053859756e-5; -2.7491587388776448e-5; 6.290053715924689e-5; -3.6213270504914916e-5; -0.00013411113468458962; 7.257135439419504e-5; 1.4270464134813745e-6; -0.0001296575501327329; -5.72737844777062e-5; -3.5657078115046794e-5; -0.00020835806208172958; 4.5357606722895475e-5; 2.928108006011229e-5; 0.0002205852797484001; 8.454694761892526e-5; -2.900094659705948e-5; -0.00010089229908767598; -9.882357699110179e-5; -9.148981916928102e-5; -2.8624448532315465e-5; -0.00021886563627015128; -1.8664666640687395e-5; 6.0547652537871666e-5; 0.00011555568926263112; -5.002941907146932e-5; 1.4415855730476081e-5; -7.385066419363876e-5; -0.00018664482922730298; -0.00015415619418489233; 8.601636363893853e-5; -2.4320150259825416e-5;;], bias = [-1.6979237463734911e-16, -1.1693346567401251e-17, -2.0714181083352786e-17, 5.376550035188711e-17, -4.906723829081721e-17, -6.338977227867032e-17, 1.3237212766390795e-16, 1.2309703422710318e-18, -1.1052983760577275e-16, -2.702744555932343e-17, 1.1787761543991422e-17, -2.7355511801934445e-16, 8.802826165414905e-17, 6.792349928226234e-17, 4.237638890831353e-16, 2.2498584164330148e-17, -2.3213311641808847e-17, -4.837795687098545e-17, 1.3636165954765412e-16, -2.576438457188304e-16, -4.455763810708869e-17, -3.864015573518605e-16, -1.509708203345852e-18, 6.958576805305756e-17, 7.00192154418249e-17, -4.2986194524094044e-17, -1.3015559828139428e-20, -1.1544283927376042e-16, -3.21797698563033e-16, -1.42627961546394e-17, -2.0791016602500844e-17, -6.392997157052292e-17]), layer_3 = (weight = [8.831261256505505e-6 0.00020744095530922127 -8.830944130800217e-5 2.1460597453780728e-5 -7.643503482590874e-5 -0.00010209305179208692 -1.0922122383405656e-5 4.04921287913101e-5 -7.381039318764843e-5 0.00014566269095617047 -0.00020632408410215268 5.3506219446247204e-5 0.00011588638011694495 -8.592908994664915e-5 4.5170915145251716e-5 -1.832759295261489e-5 
0.00011615489933269551 -4.5783058849831975e-5 -6.683670068063609e-5 -0.0001361647896924701 -0.00014403950041064542 0.00014925828003423401 -4.0531148952642005e-5 0.00012162415663574548 0.0001658034438509816 -0.00011682288601200456 -4.652578860309499e-5 -0.00013943064692334417 -9.683585954964988e-5 -6.104581694720073e-5 -4.5120928517049254e-5 -0.000164550890586316; 6.88024229304013e-5 4.004975929195494e-5 -6.994122912934819e-6 6.975603903914412e-5 -8.546144815124176e-5 5.734452864193743e-5 -0.00019147831332599132 1.7520993452429284e-5 -1.0983996588817382e-5 -1.4535479721293858e-5 -3.507922880895973e-5 -3.6126300955257154e-5 1.639463155800914e-5 -0.00019153639001967017 -0.00012208683617139944 -8.876829085135114e-5 -2.226128341941724e-5 0.00011609419746595388 6.707772266157833e-5 -0.00010380580349088792 5.215456468371172e-6 -3.147797539789985e-5 0.00010139683584487407 -3.863708113282393e-5 -7.856831504692192e-5 1.5080189797240273e-5 0.00021030734208682885 0.00012278630313734213 -4.729201825179706e-5 -8.614909162938819e-5 -2.8062766404093373e-5 -0.00012731349121974426; -0.00019290063957596302 -4.416896189648964e-6 1.7916805318965508e-5 5.148444668593715e-5 -0.00019748229553369907 7.896958393830698e-5 4.5029922858522184e-5 0.00010367848861351509 -0.00019733872633805605 0.0001766565981422913 -0.00010151906482787903 -0.00018269791345955922 -5.981714295485157e-5 5.056678108376351e-5 -9.560000782512253e-5 0.00011686419000470907 -3.563284937554139e-5 6.893820937695102e-6 -5.321646332886174e-5 0.00011983553829923168 6.510569053789501e-5 0.0001385629582553531 -2.544207490426321e-6 0.0001422699859002036 1.9917111586256757e-5 -5.241330674762007e-5 -1.8178583758707227e-5 -4.1441279080753304e-5 6.486040345480548e-5 2.1808878755838433e-5 -6.745522134151414e-5 8.602863070409645e-5; -1.6838297115702677e-5 -6.0890526102109436e-5 0.0002071803160373126 6.873965915413848e-5 8.072870995722139e-6 6.668954295905112e-7 1.1173272950450398e-5 7.922704983623685e-5 -0.0001401099156925334 
-2.9745953243642652e-5 -0.0002377765062835059 8.887014567277745e-5 -8.969521283749356e-5 -4.754475728138301e-5 -9.466971229870072e-5 -9.277693125716932e-6 9.422581072552176e-5 -9.072749659995751e-5 0.00014492435596306574 -0.00021853183122462927 6.302050908682913e-5 -0.00014847979898207145 -6.538713339129451e-5 9.374541631890668e-6 -0.00013655347126646875 -0.00012609470420162956 4.2276502908254056e-5 8.818044309858614e-5 -0.00017586971941513282 0.00011295667948556916 3.656255878862518e-5 5.2111309313960665e-5; -3.873583338495315e-5 6.521220419167663e-5 -3.000062790156202e-5 -6.004997996419575e-5 -8.611914310526902e-5 -3.0984526545900075e-5 0.00015489396972193456 -2.9034051493316308e-5 -5.005218483197595e-5 -2.7911484888581443e-6 8.991885332002438e-5 7.09258972907671e-5 1.788801162697496e-5 -5.748525219892013e-5 0.00012185206358630591 -4.9762885481277e-5 -4.539850053183596e-5 4.9069702783208385e-5 -0.0001680484194533619 0.00014117903700175725 0.00011582805089497147 0.00017819523853326608 -6.021323059347629e-6 -0.00020120232440453854 9.345037031530359e-5 -0.00012868121049755137 7.748935587780167e-6 9.686772025966158e-5 -1.0356354522040368e-5 4.994440749364714e-5 -5.3067367133143844e-5 -9.037788478812991e-5; 2.9837364668607522e-5 -7.255918226717269e-5 9.412531598657119e-5 -1.7601851691028175e-5 -0.00011433450763859097 6.476099897129521e-5 -8.582638165960258e-5 -9.71673440332551e-5 8.722926522325764e-5 -6.230418567126823e-5 7.79494725023372e-5 -3.9515384382563116e-5 0.00015032637692941124 -0.00010361221266491589 -1.7374514395364933e-5 -0.00011547665836103396 3.2510393243979025e-5 -2.9111265216372785e-5 -2.6433343559504436e-5 0.00010712354505387952 3.352387229399505e-5 -0.00011909475197595874 2.9477624953256198e-5 -0.0001301827584072036 6.790878196200332e-5 3.28485397361197e-5 -1.0269319868416344e-5 0.0001463392831240764 -8.779893742454805e-5 -9.167825974570425e-5 4.4838255594483194e-5 7.287201096489771e-5; 2.931567499005025e-5 0.00011686153404994235 -4.43915055005612e-5 
-0.00016968396127808272 -0.0001290805671236271 8.282977116111226e-7 8.098161154308173e-5 5.001380079888765e-5 -1.525744798739205e-5 1.0523082462581908e-5 4.597593536132216e-5 -9.41485782025529e-5 -9.113363237006282e-5 -1.9243087510721256e-5 -0.00011086365018152454 2.8678444803208862e-5 -0.00016603937590148153 4.309620228711076e-5 1.7683499688850403e-5 -0.00021717154638494122 -8.254096653431515e-5 -7.573756781519315e-5 -3.393842095861138e-5 -6.099239587669664e-5 -7.083943680511793e-5 -0.00016590896163720266 -5.485916947100233e-6 6.213423287377536e-5 -0.00011266392582210565 8.067546835053272e-5 -9.70357436675288e-5 7.051620152892652e-5; 3.211305753631191e-5 1.885607358837479e-5 -6.611006013861824e-5 -2.4384878848089113e-5 -0.00017871651274197723 1.1278786737065917e-5 -5.2806299059178574e-5 -3.147323135215717e-5 -0.00021066905522765383 0.00022792754695239397 -4.674528859488076e-6 3.3448435865099346e-5 -0.000204806292308951 1.0141624576304759e-6 -2.0450141014776996e-5 -2.400940487440009e-5 -4.414205599039759e-5 3.314725851361232e-5 -8.572825375757568e-5 7.03917432152727e-5 0.00017738068360324054 9.891306778135385e-5 5.18285343465494e-5 -5.4078292883225796e-5 -4.378294382639787e-5 2.9793916987380492e-5 0.0001986757235177892 -3.289600848382523e-5 0.00012539821560913055 7.83480901778664e-5 0.000133886609696324 0.0001416787673994065; -1.5165515496376935e-5 0.0001594265689805205 1.6939284587819678e-5 -5.476891749775376e-5 -8.409111763238374e-5 6.265113767814802e-5 -6.564764192215474e-5 -9.541900571535433e-6 2.3727751222867046e-5 8.083999428449852e-6 -0.00012188618859862377 2.831188646316618e-5 -4.331979437618897e-5 3.7869051421252645e-5 0.00011609460340926165 -1.1136172024165613e-5 0.00015486829795055614 0.00021054229579978428 4.1775436666754803e-5 2.9756647900754996e-5 2.229337257678211e-5 -4.6848404626392704e-5 -5.6246035118861605e-5 -1.2740087758615802e-5 -0.0001618018374115148 -0.0001297581853846032 -0.00013111455485112456 1.812928564551448e-5 9.450212234744231e-5 
-5.69730997353688e-5 -0.000142155995153629 -6.671131416577329e-5; 2.326103128672683e-5 6.894452633225089e-5 2.0819256291196028e-5 3.692786083212756e-5 -8.031168147334158e-5 -9.943094085081973e-5 6.601222024817594e-6 0.00011862619828554559 0.00024392707962198637 0.00010147181126177953 0.00024618865097529005 -0.00010740802429262465 0.0001289877731085821 -5.277691127610646e-5 -8.048152415193608e-5 -8.626343662961728e-5 -2.9888016457499672e-5 9.094417725811327e-5 -7.660844460240945e-5 3.3504038923472304e-5 -0.00022864076493186197 -3.701389107499252e-5 -0.00014431251804688264 -2.5558457879180218e-5 1.4219748094380275e-5 4.567583925910949e-5 -2.4882921594490876e-5 -0.00017992214600433913 1.9670391678858712e-5 2.9790630263792204e-5 -2.3640608220541573e-5 0.000164081186425227; 9.403895808114515e-5 -4.407513563441036e-5 6.9464998652274116e-6 -0.00010516031045502978 1.7740360355326082e-6 -8.675464062056995e-5 0.00016346877225638327 -0.00013672032957932767 6.352581829028685e-5 2.0663891283774986e-6 8.986662567091275e-5 -2.130486075280229e-5 -3.674942866003999e-5 5.6754386491012994e-5 6.381443994672096e-6 -1.3107323980730629e-5 8.414419414672185e-5 6.130633107582824e-5 4.750599680315503e-5 6.209869741192865e-5 7.276532403390039e-5 -0.00011773795279246442 3.6893521495732896e-5 -0.00014151413790917938 -1.9589559202216107e-5 0.00010307756008311231 1.4009943583152718e-5 -9.001019509545576e-5 8.82541061165727e-5 7.125950639797609e-5 6.137949810559744e-5 -9.689149021672892e-5; 5.3418678258534897e-5 -8.229858041370004e-5 -3.19964657267723e-5 -0.00010109763217253175 9.998815392980582e-5 -5.1301884117772525e-5 -6.505282919017061e-5 -8.002410878755022e-5 -6.777872305456985e-5 -5.459287978221278e-5 2.0315290281067095e-5 -8.221853032799684e-5 9.968223629191805e-5 -1.2017758412477929e-5 -0.00013010110534585854 -3.19274860106061e-5 -0.00010573958399101897 2.2165659070820392e-5 0.00015548891850420226 6.243602718830885e-5 6.213347831879004e-5 -0.0001200646057571088 5.785611018443637e-5 
-8.325965437899863e-5 -6.221607155958892e-5 -4.000509772078743e-5 0.00015179501668688053 -0.0002315579023291488 1.111525206006855e-5 6.715812367992575e-5 -2.1893742190801976e-5 -2.724150555361352e-5; 9.93164922445828e-5 -4.783079379845372e-5 8.996252338733015e-6 -1.3641882967163954e-5 9.356166815709364e-5 4.260181781577864e-5 2.8264565164903455e-5 -0.00013694241529006466 7.830434156207359e-5 -0.0002437522070496526 -0.00010778379687249619 -6.818159631362282e-5 -0.0001529814489432793 -7.271714452795642e-5 5.7389186670518285e-5 -7.146361344635464e-5 4.8111292974539533e-5 0.0001239354297038914 1.170493492540451e-5 0.0002021955279386418 2.987886139493626e-6 0.00012707807586852346 -9.916971789348827e-5 0.00019427630291178854 -9.388096254693341e-5 -0.00021244616250183396 6.387856233490675e-5 -6.020005998576032e-5 0.00014380630277886918 2.2277898145005708e-5 7.557929214129897e-6 5.124792556020475e-5; 0.00015167884483790827 -7.365516140733806e-5 -0.00012964433918849497 -3.934589051902424e-6 -1.7024158642799134e-6 -3.1641552062038456e-5 7.276004688438911e-5 0.0001301680000205254 -8.183260289396511e-5 -0.00016899149560512037 -6.975395002051375e-6 6.499888476073615e-5 0.00017035436538836955 2.7824425694156195e-5 0.00012524697884773572 -2.852091177453796e-6 -0.00011721162665079278 0.00013792376338361584 0.00016224659119590018 -3.1065350788674694e-5 -1.5897485730498334e-5 0.00018744729239815602 1.1190567044572336e-5 -2.123177785585338e-6 -5.0808039680386715e-5 -4.0182300302227695e-5 -0.00014370354111914438 -3.500443782567099e-5 6.52482827598798e-5 -1.958330941587758e-5 -0.0001296202266649633 -3.928936542330015e-6; 8.42506378835027e-5 6.243446478080295e-5 -6.978046076213519e-5 0.00020519586905762632 5.364232038541363e-6 0.00010630450196953932 1.5156634618015346e-5 0.0003081433351330509 -2.655600837491395e-6 6.717878932291793e-5 -0.00013854903590939573 1.1401561096513607e-5 3.1574203604090056e-5 8.529193655746469e-5 9.750941171697287e-5 4.1457547923035156e-5 4.0338629333259936e-5 
6.806154488045134e-5 -8.359898129087486e-5 -1.4803310477458156e-5 -5.116380270877439e-5 -1.795012758483617e-5 0.00012767332009431439 7.275491589140257e-5 -3.860248398227614e-5 -0.00014634197941589864 -3.871366515826517e-6 -6.382786159478209e-5 8.334286030769846e-5 -0.00012072591562133465 -5.229850648860737e-5 0.00015253501987942003; -7.895188348486591e-5 -9.175579355217249e-5 4.5075330557276075e-5 0.00012195721326276976 0.00016450122172825292 -5.393970962033871e-5 1.260005962974276e-5 -0.00013412023138373872 -3.316856775784226e-5 0.00016616647014741154 1.095588602181308e-5 1.4356961467986923e-5 2.152377971794726e-5 -0.00010811499844424189 0.00010876348649033847 6.211187253354515e-5 -7.874498435415089e-5 -0.0001359116012521786 1.9112556615931325e-6 -4.34259566320727e-6 9.538937417671322e-5 -4.759788673838116e-5 -8.80740134802438e-5 -9.1504692978949e-5 -0.00020645247598297348 -5.149791279693785e-5 3.713795565415397e-5 6.1408314409829236e-6 -2.574981038052897e-5 -8.855630033069408e-5 4.1550241870548885e-5 0.00019873729440730297; -2.284237708110628e-5 -0.0001554685488433493 -2.9139316399527316e-5 -7.653120707548247e-6 -9.593777686561578e-5 -1.0107818713291326e-5 5.191946437960547e-5 9.074227527278295e-5 -7.642669261869401e-5 6.891157935867985e-5 -9.049213251491102e-7 0.0003136335097163255 -0.00017960655278031803 -0.00014517279233289095 0.00011356963186090024 0.00023227378077834054 -4.9839412173900305e-5 8.012652942241446e-5 -2.7348519512742964e-5 -0.00010349712575291088 -6.726058673550516e-5 -6.828401565753745e-5 4.451626221703583e-6 -5.72426044588817e-5 9.665962242550637e-5 -4.988926703547195e-5 3.271546841711426e-6 -0.00020195638827019309 -1.8348052623612e-5 0.00014707329209458038 -7.80849633945044e-5 -0.00011723531243362617; 0.00013769289888944063 -0.00016804620155096134 0.00013072739183474312 -6.657363029159786e-5 -5.965801811623434e-5 -3.905602044040301e-5 7.638762998897335e-5 3.693759516966548e-5 1.5004880325359785e-5 -1.7592338380265587e-5 -9.228591134967327e-5 
0.00012982518764247792 9.148687912886877e-5 0.00013114826232697791 6.148937190219167e-5 0.0001352755201797898 8.61062325487123e-6 -5.58505276866307e-5 -5.2003568824944195e-5 -5.303168710167926e-5 -6.852256102217385e-5 -1.4205649321317631e-5 7.330736788492651e-5 -0.00028237653951589185 5.5023188352717935e-5 -0.00017765064022466648 9.749667459903294e-5 -0.00012150635357909162 0.00016387976883480435 2.540126242012459e-5 0.00016329846347719074 0.00014361238517370065; 0.00011359632237655647 0.00021731061617686007 -3.5132767765729175e-5 0.00018872706034528337 0.00013970487391418471 -5.134594055052739e-5 -7.208474805736854e-5 2.7306917514241025e-5 -0.00014520751089948882 4.418713554116308e-5 -0.00016831704367575536 -9.42788635024408e-5 0.00014792603663011037 0.00012892739080200662 0.00013377249463296883 0.00016895678466027393 7.096972451620494e-5 7.926473092594475e-5 3.336830316185368e-5 0.00017995440914602567 0.0001420016900061699 1.8269592442768933e-5 -0.00017072193684213965 1.1915585969857025e-5 -5.2059908443327756e-5 -9.820382062969273e-5 8.367122548009116e-5 5.211498621576499e-5 5.9977731621596346e-5 -1.4192731827651333e-5 -9.050437317854038e-5 8.306768479599412e-5; -0.00019107184429202834 4.447921806813926e-5 -2.127690326205781e-5 -2.5733389101964996e-5 -1.0475763835503613e-6 8.497337567240054e-6 -0.00017555608295851335 0.00010323330703640899 -4.2316549788377245e-5 -0.00010646249806058706 0.00014762294946222282 -0.0002056205144432699 -3.33036937624424e-5 -0.00013462563450930028 8.500501699333317e-6 7.666523252108386e-6 -4.463444796756228e-5 -4.5710172834959805e-5 0.00013088155306860305 8.218642820821458e-5 -2.0972766414795487e-5 1.7600830197212355e-5 -5.636746804129693e-5 6.678815903824674e-5 -1.7498864089871874e-5 1.9208440299102177e-5 3.561571198615805e-5 -0.00017748443000500025 -4.217553081586279e-5 7.13901731254042e-5 -2.4807416175182446e-5 -4.867173132852059e-5; 2.9350580841595702e-5 1.2967373474843789e-5 -0.00013705671797897316 -6.617123206769282e-5 
4.667978366096928e-5 -3.2755665490797195e-5 6.149208722245882e-5 -0.00016187678473461164 2.2364673820744215e-5 3.957407981698147e-5 -3.8022858536984306e-5 7.516877023350693e-5 -0.00010175535779371338 -5.6396452312307495e-5 1.320329731003836e-5 6.978499088046384e-5 -4.089993952465811e-5 -5.306603550951295e-5 0.00016774327461218112 -0.0001507387197149784 -0.00012849515686202436 -0.00012396849259187575 -4.008800450055559e-5 0.00013407204907462723 -0.00020112084083182628 -0.00020751227483078262 9.101331753963886e-5 -0.00011865209588239876 0.00022073784286235533 -0.00017107849915434087 -4.7941000123216806e-5 -0.00015487834669574763; 2.6035871500969478e-5 -0.00019728663077822386 3.0431368445162062e-5 4.706183476258949e-6 0.00012734081197285542 0.00015395710893663418 -0.00011012707595491969 6.148759662627592e-5 -0.00011764858258665788 9.817616459798438e-6 -3.155562857992502e-5 -4.4316188491732965e-5 5.423733041842972e-5 -0.00014618589243164801 -9.517452467241037e-5 3.4742326037846054e-5 3.06714313906845e-5 8.91536178229209e-5 -0.00012393346746230347 -5.8153382765706145e-5 2.0970364361884452e-5 -9.00272066659178e-5 -9.549766714743779e-6 -3.5348285067868804e-6 -9.753130193111993e-5 2.2629011668498412e-5 -0.00020262059164943735 -5.236424162413164e-5 -0.00011520865568434642 4.465272964317627e-5 0.00014881456395832794 2.9520262118777474e-5; 0.000152306976421098 2.6964323811992402e-5 -0.00010536289158480665 -7.212099120224345e-5 0.00017887728933274318 -0.00019535296126705227 5.65368569063163e-5 -0.00015461307014774948 -3.068053450406152e-5 0.00018808709461736977 -0.00012573878160182815 -0.0001309381809129271 -5.037100830665699e-6 2.5973440248347316e-5 -0.0001744973455366574 3.175710623181881e-5 -2.46509993608897e-5 0.000124196939911055 1.7228810580856785e-5 -0.00011667209440347294 -1.3877875456816555e-5 0.00012117433977109498 0.0002277546987592387 1.9097405644427245e-5 -0.0001274162827336096 5.666627800237882e-5 0.00012424788616627037 2.9740232609040146e-5 5.598292006325204e-5 
-7.006810886523368e-5 -3.81898102510582e-5 1.3198860475679841e-5; -0.00016249895496555603 -3.1445767022961294e-5 -3.771025375143082e-5 -1.91266758393612e-5 3.5385737204631915e-5 -9.924311096808385e-5 -3.660123046400184e-5 -9.076712238864469e-5 -7.249135750141498e-5 3.533735572447739e-5 -0.00016450115298183068 9.285479383283215e-5 -0.00015069681329449032 -1.4367973095506077e-5 -0.00012104637652196735 0.00010373252508900064 2.8407930712258625e-6 7.94547805883081e-6 3.316342328064214e-5 -0.00012307225778029305 -5.259676844260109e-5 1.40028055411173e-5 -6.035622799216248e-5 3.891738152917234e-5 0.00012527775608207352 0.00012892283622379112 6.933686883159495e-5 7.927740581537448e-5 -7.842409330447719e-5 7.204721397459357e-5 0.00012982282852492242 -5.562417071458349e-5; -8.238643679270378e-5 -9.180172967227029e-5 -0.00023732115116637052 -4.466478662480079e-6 1.2434555879180861e-5 -5.5137277035655976e-5 5.447164110172684e-5 7.206398984216302e-5 0.00010630658905519309 -9.10076171063431e-5 -3.428375678105132e-5 -0.00013665370827276144 -5.0395077073162815e-5 0.00017883763464432358 4.974636137866525e-5 1.4724801307145117e-6 8.153371416064128e-5 0.0001625052634839355 -1.6412826175483398e-5 -3.732902152897648e-5 9.162778747295109e-6 3.1936119535557167e-5 1.053333815015302e-5 -3.0751825056186346e-5 4.400159630257775e-5 0.00012857667277366069 -0.00011185636433275268 -3.227657837179345e-5 -5.312162306482162e-6 -6.083633090803269e-5 0.00017337589340549105 -3.2308185131669424e-5; 8.589769492543839e-6 9.804282257115569e-5 -3.853847061967421e-5 -6.475253081865436e-5 -0.00012736202967338184 3.4650155948577414e-6 0.00010525882629459941 -4.1480533184800004e-5 8.177984221853949e-5 -0.00015544286048956068 -0.00012600449605364272 -4.872458572482965e-5 -6.124358663360023e-6 8.667249443257286e-5 -9.93996077913986e-5 0.00010555946886321814 -2.0396634284257288e-5 -5.270140587802582e-5 -7.348917330367111e-5 -0.0001376792502588001 5.738792918870284e-5 -5.0700128270379465e-5 0.0001305587788346787 
-4.495114677687015e-5 0.00011263029540438715 -1.364864016344186e-5 -4.567971751661003e-5 1.571822343700325e-5 5.224545333201414e-5 5.0844063877751555e-5 -3.26822633387864e-5 -1.7933580750682817e-5; 0.00013921876405718412 -0.00015577480927688047 -6.850668900840973e-6 -3.3375794477724383e-5 7.970332979810483e-5 -8.095941131116726e-5 3.745014499643663e-5 4.36374956602616e-5 3.575578910883674e-5 0.00010316800206554076 3.941476632961131e-5 -8.579026247883321e-6 9.985027541443879e-6 6.237632795590058e-5 -5.7916995605907056e-5 -3.582679174953396e-5 -0.0001487710743408735 7.777163581111049e-5 4.548405002922357e-5 0.0001033136376331435 -9.126199266985905e-5 2.8278867726343496e-5 -0.0001670092225272291 0.00019269037319234503 -8.205675034810615e-6 9.913379643465388e-5 -9.774225883531267e-5 0.00020677629243935294 -5.929176233124796e-5 2.3476481042478878e-5 5.8685540248441616e-5 -3.8897588770105195e-6; 4.774877578165626e-5 9.520094880282137e-5 1.126072146145051e-5 -0.00012252544712827623 0.00011983819984744582 3.577738078455326e-5 0.00012217599413270804 8.298279081033871e-7 0.00010412017416819042 0.0001427246226594505 -1.5691382774083815e-5 -5.835144640154966e-5 9.82510520625598e-5 -5.7819151894353645e-5 8.41076964060815e-5 0.00010987127296539445 5.702792821831412e-5 -6.914305518878279e-6 -0.00015956177395855942 0.0001004615461934011 4.525976859451014e-6 7.244913848257187e-5 7.101636418520884e-5 6.0740485590947414e-5 1.4536745933576495e-5 6.751654126129232e-5 4.215705319803638e-5 -7.412868018581813e-5 1.2954344296268458e-5 0.0001547380290357772 -7.821484343003092e-5 3.1391764590818486e-5; 0.00011686651452555315 3.1425428113386137e-5 0.00016147851021606444 -2.244975110579224e-5 0.00011101051097148932 2.7632642394097456e-5 0.00015894796126167656 -1.4028402239537699e-5 0.00012562558878788672 -0.00017123934629261826 -5.149310925980191e-5 0.00026991687956838485 2.3489366126402895e-6 0.00011192073326902421 -1.373008797741332e-5 -1.9034042134062966e-7 3.410496554814325e-6 
-0.0001294319121202568 -9.811404414131466e-5 -4.375881360965776e-5 3.469282784727363e-5 2.1141458654208298e-5 -1.794289926828335e-5 0.000191576352898214 -6.947935287963249e-5 -0.00013418522247001766 0.0001803664596250284 5.7427998052391916e-5 2.8150919768872373e-5 1.058295185291967e-5 -7.040960446374928e-6 6.224080026959988e-5; 9.136531497107772e-5 1.1279599844886363e-5 3.475632382661588e-5 2.6897303711042738e-5 -0.00010704935795173466 -0.00013932559323973502 -4.3366666262850424e-5 9.911872819985711e-5 -7.812285713396581e-5 -6.918974196549819e-5 -3.602554956057689e-6 -7.410949712162936e-5 -0.0001266572342627906 9.729142965271629e-5 -3.0731591153981304e-5 -1.296624015038482e-5 -9.444690459885144e-5 -4.244595567241269e-5 1.3374883734393142e-5 1.3279254004479802e-5 -0.0001453852581233948 0.00021653411584897098 5.08244649482658e-5 0.00015191821347444953 0.00016467071162525772 -6.675388959946465e-5 -1.4839628251036714e-5 -2.380470852029712e-5 -8.346584019261024e-5 -0.00012618813872347767 2.57537359958551e-5 -0.000141721842568484; 0.00010958985034802824 -2.493471152420489e-5 5.7843920098001006e-5 6.6499794628672534e-6 4.797733594741413e-5 7.347831228242051e-5 0.00019201413919644672 1.271823907145414e-5 -0.00013187556510831314 -8.946351341010547e-5 -9.689257900036197e-5 0.0002186353983632644 9.089648187354797e-5 0.00010381017886529466 6.439153616505863e-5 -9.534345968959011e-6 -5.386616726909973e-5 3.611080568942882e-5 0.00018722807248549293 -2.5828024496250806e-5 4.357695683565715e-5 1.6226015526482728e-5 -9.879799222844808e-5 5.560810927262945e-5 0.00010153224388330539 1.0374074685292616e-5 8.334270882095206e-6 1.7619486919913795e-5 4.2481681470105534e-5 -0.00016588515177875664 -9.271411050676734e-6 9.720793854639683e-5; -4.913032179112513e-5 -0.00012941810208368356 3.447080226968852e-5 -0.00013586430949154754 -2.94612910563144e-5 2.1876747190863877e-5 -5.52387890430976e-6 8.603569689776756e-6 9.754846897910475e-5 4.6078021036270966e-5 4.6626333564099055e-5 
7.294097660109641e-5 3.505495616472073e-5 -8.252686173883399e-5 -4.791626641759939e-5 -6.380951170805676e-5 -1.9027015338999555e-6 3.18464844115437e-5 3.4260119641003146e-5 -1.5496903740482326e-6 7.911594358510208e-5 0.0001343844886962303 7.615037240092563e-5 0.0003527685002284549 -1.2819854496887355e-5 4.4403812275311725e-5 -0.00024189493350214392 0.00029525179335385896 4.2876101282901913e-5 9.675090033330377e-5 -7.955073314016545e-5 0.00014243916295007697], bias = [-2.2133104277599905e-9, -1.709821723890786e-9, 1.4352555322729281e-11, -2.2423325107801296e-9, 4.445443224355152e-10, -5.535227273059002e-10, -4.2059331012156634e-9, 1.4517032641723618e-9, -6.926979819290503e-10, 4.2598350663478234e-10, 1.0985551509752595e-9, -2.3847990842564265e-9, 1.000924785459116e-10, 8.127125641191906e-10, 3.2996472021557102e-9, -7.584516682295364e-10, -1.387167435078597e-9, 1.5710530804566214e-9, 4.4067979871360365e-9, -2.6071026680856624e-9, -3.942734745566996e-9, -2.314250568164728e-9, 2.5790214122345366e-10, -1.5884198384268388e-9, 2.4352075154134458e-11, -1.106457635466508e-9, 1.6533572146423857e-9, 4.167472082109472e-9, 3.619441949399424e-9, -1.5841705028014582e-9, 3.3348841487782253e-9, 3.0958420126804027e-9]), layer_4 = (weight = [-0.0005893760151152129 -0.0005860714096631461 -0.000741374782809004 -0.000605284915855523 -0.0005773993446138456 -0.0008438795283904732 -0.0008239465141780389 -0.0008121379162152775 -0.0007350899189486165 -0.0008957424989827458 -0.0004799786594224046 -0.000669993314491262 -0.0007437099174958013 -0.000810442987611029 -0.0006687746100767417 -0.0009314047653542297 -0.0006434170804564896 -0.0006716618512541457 -0.0005674424521008677 -0.0006562524050441192 -0.0007718874801474663 -0.0006568769856327636 -0.0004995041241921094 -0.0005974355275492605 -0.0007209748748466796 -0.0006846526218921775 -0.0006725545183410523 -0.0006616622428087592 -0.0005571310698370337 -0.0008568168099736145 -0.0008015000647591991 -0.0006712003446673714; 0.0003873435458205753 
0.00020595554825212888 0.00024220652962518837 0.0002989046076246199 0.00033709685246147267 0.00019160499853693278 0.00018216385560573174 9.570962331452566e-5 0.0004162277057632122 0.0003256419121666149 0.0002765429606406636 0.00014307049954532207 0.0002557025841801893 0.00017531624385999355 0.0002539739980346544 0.00038387016779728726 0.00020570742541710762 0.00016139942582153516 0.0002909471832474673 6.375304397134565e-5 0.00024134980921489534 0.0003330618272553611 4.2977530134993194e-5 -7.539131722098868e-5 0.0002585794869835237 0.00020448408032104756 0.00020616098753456337 0.00027130246573540024 0.00037807996662542247 0.0002141939688965494 0.00017673426752328082 0.0003180552763281201], bias = [-0.0006856571583228618, 0.00023262149502261293]))Visualizing the Results
Let us now plot the loss over time
begin
    # Loss recorded by `callback`, one point per optimizer iteration.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; marker=:circle, strokewidth=2, markersize=12)
    fig
end
# Finally, let us visualize the results
# Re-integrate with the optimized parameters `res.u` and compute the trained waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    # Reference waveform vs. the prediction before training (`waveform_nn`) and after
    # training (`waveform_nn_trained`); each series is a line plus circular markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5,
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5,
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
# Appendix
# Print the Julia build information and, for each GPU backend that is loaded and
# reported as functional by MLDataDevices, that backend's version information too.
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
# Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.