Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.
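We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$, and use the Newtonian formulas $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, with $M = m_1 + m_2$, to get $r_1$ and $r_2$.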
"""
    one2two(path, m₁, m₂)

Split a one-body (relative) path into the paths of two bodies with masses `m₁` and `m₂`.

The first path is `path` scaled by `m₂ / (m₁ + m₂)` and the second is `path` scaled by
`-m₁ / (m₁ + m₂)`, so that `r₁ - r₂` reproduces `path`.

# Returns
The tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = (m₂ / M) .* path
    r₂ = (-m₁ / M) .* path
    return r₁, r₂
end

# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix into planar orbit coordinates.

Each column of `soln` is one sample. Row 1 is `χ` and row 2 is `ϕ`. When `soln` has two
rows, `p` and `e` are taken from `model_params = (p, M, e)`; when it has four rows, `p`
and `e` are read per sample from rows 3 and 4 and `model_params` is ignored.

The radius is `r = p / (1 + e * cos(χ))` and the result is the `2 × N` matrix whose rows
are `r .* cos.(ϕ)` and `r .* sin.(ϕ)`.

# Throws
- `AssertionError` if `size(soln, 1)` is neither 2 nor 4, or if `soln` has two rows and
  `model_params` is `nothing` or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    nstates = size(soln, 1)
    @assert nstates ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nstates == 2
        # Check for `nothing` first so a missing argument reports the intended message
        # instead of a MethodError from `length(nothing)`.
        has_params = model_params !== nothing && length(model_params) == 3
        @assert has_params "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    orbit = vcat(x', y')
    return orbit
end

# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled signal `v` with spacing `dt`.

Interior samples use the centred difference `(v[i + 1] - v[i - 1]) / (2dt)`. The first and
last samples use three-point one-sided stencils, so `v` needs at least three entries.
The result has the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    first_pt = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    last_pt = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [first_pt; interior; last_pt] / dt
end

# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled signal `v` with spacing `dt`.

Interior samples use the centred difference `(v[i - 1] - 2v[i] + v[i + 1]) / dt^2`. The
first and last samples use four-point one-sided stencils, so `v` needs at least four
entries. The result has the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    first_pt = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    last_pt = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [first_pt; interior; last_pt] / (dt^2)
end

# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the mass-weighted moment tensor of a planar orbit.

Row 1 of `orbit` holds the `x` samples and row 2 the `y` samples. For `component` equal to
`(1, 1)` or `(2, 2)` the result is `x²` or `y²` minus one third of `x² + y²`; for any other
`component` the result is `x .* y`. The result is scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    xx = xs .^ 2
    yy = ys .^ 2
    third_of_trace = (xx .+ yy) ./ 3
    # Evaluated lazily so `component[2]` is only read when `component[1]` already matches.
    on_diagonal(k) = component[1] == k && component[2] == k
    entry = if on_diagonal(1)
        xx .- third_of_trace
    elseif on_diagonal(2)
        yy .- third_of_trace
    else
        xs .* ys
    end
    return mass .* entry
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the moment-tensor `component` of `orbit`,
where the derivative is taken by `d2_dt2` with sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the `(1, 1)`, `(1, 2)` and `(2, 2)` quadrupole components of `orbit`.

# Returns
The tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    entry(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return entry((1, 1)), entry((1, 2)), entry((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain of a single orbiting body of unit mass.

# Returns
The tuple `(c * (h11 - h22), -c * 2 * h12)` with `c = sqrt(T(π) / 5)`, where `h11`, `h12`
and `h22` come from `h_22_quadrupole(dt, orbit)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    prefactor = sqrt(T(π) / 5)
    plus = prefactor * (h11 - h22)
    cross = -prefactor * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components of two bodies.

# Returns
The tuple `(h11, h12, h22)`, each entry being the sum of the corresponding entries of
`h_22_quadrupole(dt, orbit1, mass1)` and `h_22_quadrupole(dt, orbit2, mass2)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain from the orbits of two bodies with masses `mass1` and
`mass2`.

The masses must add up to one (checked to within `1.0e-12`).

# Returns
The tuple `(c * (h11 - h22), -c * 2 * h12)` with `c = sqrt(T(π) / 5)`, where `h11`, `h12`
and `h22` come from `h_22_quadrupole_two_body`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # The two masses are expected to be normalised so that they add up to one.
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    prefactor = sqrt(T(π) / 5)
    return prefactor * (h11 - h22), -prefactor * (T(2) * h12)
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain produced by an orbital solution.

`soln` is converted to an orbit with `soln2orbit(soln, model_params)`. For
`mass_ratio > 0` the orbit is split by `one2two` into two bodies with masses
`m₁ = mass_ratio / (1 + mass_ratio)` and `m₂ = 1 / (1 + mass_ratio)` and the two-body
strain is returned; for `mass_ratio == 0` the one-body strain is returned.

# Arguments
- `dt`: sample spacing of `soln`; its type `T` is used when forming the masses.
- `soln`: solution matrix accepted by `soln2orbit`.
- `mass_ratio`: value in `[0, 1]` (checked with `@assert`).
- `model_params`: forwarded to `soln2orbit`.

# Returns
The tuple returned by `h_22_strain_two_body` or `h_22_strain_one_body`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

# Simulating the True Model
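RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.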
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place right-hand side of the relativistic orbit ODE.

`u` holds `(χ, ϕ)`; the second argument is destructured into the constants `p`, `M` and
`e`. Neither `ϕ` nor `t` enters the expressions.

# Returns
The two-element vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)
    # Factors shared by both rates.
    shared = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root = √((p - 2)^2 - 4 * e^2)
    dχ = shared * √(p - 6 - 2 * e * cosχ) / (M * (p^2) * root)
    dϕ = shared / (M * (p^(3 / 2)) * root)
    return [dχ, dϕ]
end
# Simulation settings. The values that are read as globals by the solve and loss code are
# declared `const` so that their types are fixed when those functions are compiled.
const mass_ratio = 0.0 # test particle
const u0 = Float64[π, 0.0] # initial conditions
datasize = 250
# NOTE: `tspan` uses Float32 literals, so `tsteps` and `dt_data` are Float32 as well.
tspan = (0.0f0, 6.0f4) # timespace for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
const dt_data = tsteps[2] - tsteps[1] # spacing between saved samples
const dt = 100.0 # fixed step passed to the integrator
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the true model with fixed-step RK4 (step `dt`), saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the strain tuple as the reference waveform.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with a marker at every saved time.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end

# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives.
It is generally not recommended to use globals, but if you do use them, make sure to mark them as const.
We will deviate from the standard neural network initialization and use the truncated-normal initializer from WeightInitializers.jl:
# Network used by the neural ODE: a cosine applied to the input, two 32-unit dense layers
# with cosine activation, and a linear layer with two outputs. Every dense layer shares
# the same initialisers (truncated normal weights with std 1.0e-4, zero biases).
const nn = let init_weight = truncated_normal(; std=1.0e-4), init_bias = zeros32
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight, init_bias),
        Dense(32 => 32, cos; init_weight, init_bias),
        Dense(32 => 2; init_weight, init_bias),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-8.8915265f-5; -9.6265816f-5; -0.00012949806; -8.801643f-5; -2.0153757f-5; -0.0002156531; -2.2652956f-5; -0.00013827362; -1.7835846f-5; 9.843288f-5; 2.4555661f-6; 5.831199f-5; -3.6028323f-5; 4.5859168f-5; 3.1441843f-5; 2.4407002f-5; -0.000104849525; 0.00017740538; -9.148711f-5; -0.00011924671; -0.00012321919; 8.6015105f-5; -8.8653156f-5; -7.072169f-5; 8.5449356f-5; 7.6370736f-5; 5.56605f-5; -7.1633905f-5; 9.260593f-5; -2.2408954f-5; 2.595824f-5; 5.5816796f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-6.056251f-5 5.3457545f-5 0.00016091349 2.4134446f-5 8.1149294f-5 8.0325466f-5 -5.3063246f-5 0.00026293742 -1.5516911f-5 8.2372084f-5 -0.000113703405 8.210816f-5 -3.111872f-5 -0.00020261883 -0.00010638544 -0.000107577085 7.5371536f-5 -5.5334825f-5 -1.8100787f-5 3.6576777f-5 -0.00020493104 5.309867f-5 0.00027227055 -4.4363343f-5 9.5104886f-5 -5.0723568f-8 8.521526f-5 -0.00014991967 -0.000114302136 0.00017970677 -0.00027880337 3.88955f-5; 6.0334052f-5 -3.9870145f-5 -9.9084005f-5 -2.2236092f-5 7.330838f-5 -2.8403536f-5 -0.00011742177 -0.0001137422 2.730496f-5 8.366315f-5 0.00010053122 -0.00015309351 7.553822f-5 -4.9288046f-5 -0.00012524592 0.00010918356 8.392876f-5 4.346021f-5 -5.4301774f-5 -0.00015768278 7.527691f-5 1.6847482f-5 0.00011905832 -7.8696976f-5 -0.00015681783 -8.461307f-5 0.00018905658 1.2294079f-5 0.00010818924 -9.32627f-5 7.746915f-5 -1.0636427f-5; -1.7688373f-5 5.9254897f-5 -1.5314674f-5 0.00013281322 -0.00010800169 -6.614393f-5 5.6622892f-5 1.1431871f-5 8.129055f-5 -8.1925835f-5 3.5425677f-5 5.238411f-5 -0.00015661615 3.9833947f-5 -0.00014499339 -8.5514126f-5 -9.758446f-5 2.8045548f-5 -9.8452205f-5 -9.4961666f-5 1.7721337f-5 0.000103242666 -9.5383286f-5 -3.251011f-5 0.00016126434 -7.731466f-5 
2.236857f-5 -0.00010275267 -0.00018648751 -3.4275516f-5 6.750124f-5 -0.000120274846; 8.186451f-5 -5.8016517f-6 -5.7443864f-5 -4.0082927f-5 0.00010130929 -2.312017f-5 0.00013606835 -2.8461187f-5 2.9017372f-5 -0.00019941531 -0.000109736306 -1.6373813f-5 4.4642602f-5 4.2207364f-5 1.5857962f-5 -5.196456f-5 0.00012225042 1.8400882f-5 -0.00023584883 -0.00010600877 -0.000129738 2.3410878f-6 5.131859f-5 -0.00016858317 -1.094441f-5 -6.333613f-5 -8.8967296f-5 1.2109975f-5 0.0001178615 4.3449618f-5 -2.7913558f-5 6.93078f-5; 2.8033146f-5 7.791823f-5 0.00016608149 -5.2466166f-5 5.8139005f-5 -8.0791004f-5 -6.5888125f-5 -4.997893f-5 -0.00015309078 8.182305f-5 -0.00010169868 0.00016863667 5.4431275f-5 -0.00011236579 -0.0001003315 8.661164f-5 1.330218f-5 0.00021071492 0.0001753717 3.813241f-5 4.8002497f-5 -4.1361032f-5 9.1684335f-5 -0.00016690507 -3.3047407f-5 0.00010841587 -0.00010967311 0.00017800019 7.99587f-5 -0.00014051555 -0.00010580397 3.143387f-5; -0.000187216 -5.2521435f-5 -5.8545153f-5 7.80235f-5 -7.001475f-5 -1.2909377f-5 5.5979606f-5 9.500527f-5 8.4030544f-5 -5.1437396f-6 8.182988f-5 1.1239474f-5 6.94826f-5 -1.1376742f-5 2.4269186f-5 -8.25536f-5 5.3925028f-6 0.00017558175 -9.284131f-8 -5.514777f-5 8.433125f-5 -0.00015324468 -4.2609914f-5 -9.1006776f-5 -0.00014455791 -3.6185244f-5 -0.0002030065 -9.311817f-5 -8.267748f-5 -2.76321f-5 0.00010075266 -5.7782232f-5; -6.467965f-5 0.00023088377 -4.4165652f-5 0.00021037782 -6.6909546f-5 9.834028f-6 0.00013173124 7.736485f-5 2.0738267f-5 -8.959815f-5 -6.238815f-5 3.995518f-5 -2.3252589f-5 -7.780762f-5 8.184617f-5 2.3256f-5 -7.45022f-5 -0.00014840325 -3.334281f-5 0.00016967277 8.5104104f-5 8.466957f-6 -8.4854095f-5 -8.224913f-5 -0.0001727673 2.8423365f-6 -4.800741f-6 -7.964143f-5 -7.678097f-5 -0.000110360656 -3.6295813f-5 7.125732f-5; -4.23211f-5 -0.00020053994 5.3836593f-5 -0.0001245354 -4.5564448f-5 -2.28038f-5 9.836009f-6 7.733122f-5 2.5114454f-5 1.2144471f-5 7.002792f-6 -7.5188436f-5 2.0882533f-5 7.775267f-5 0.00024124263 
5.547853f-5 -8.251284f-5 -0.000102589314 -7.1460156f-5 9.057194f-6 0.00013611936 4.458467f-5 -4.021292f-5 -1.2661196f-5 -1.8434242f-6 -6.0880495f-5 1.1976202f-5 -2.1664622f-5 -0.00022910666 0.00012493492 -0.0002026236 -0.00026658468; -5.548841f-5 0.00012804483 7.841857f-5 5.873229f-5 -4.3569567f-6 -0.00023105612 1.908292f-5 7.275275f-5 -8.2622006f-5 -0.00016623572 0.0001644236 -0.00011853807 -7.963172f-5 -2.3004866f-5 -5.146604f-5 0.00015801414 -4.2501044f-5 3.959334f-5 2.734249f-5 -0.00012313723 -0.00015830512 -1.1559633f-5 8.468042f-5 5.9422728f-5 4.8592126f-5 -7.5375436f-5 -4.5701585f-5 8.947265f-6 5.1005907f-5 0.00016106854 0.00011125841 -8.6434f-5; -9.046996f-7 -9.175775f-5 -5.1525208f-6 0.00012092941 -3.0335392f-5 3.1075465f-6 9.651811f-5 0.0002693987 -0.00018144662 8.391303f-5 5.5525885f-5 -1.739893f-5 4.4818607f-5 0.00015967998 -0.00013328048 2.4780962f-5 -2.5041769f-5 7.334421f-5 9.600632f-5 0.000113070826 -3.5066332f-5 -4.3224874f-5 -1.9799076f-5 9.198721f-5 0.0001040842 0.00020631636 3.803934f-5 -2.4899951f-5 3.766592f-5 3.2835185f-5 -0.00012204292 2.7492766f-5; 2.885983f-5 5.647442f-5 -0.00015630298 0.00013754478 -0.00015074637 -0.00014203988 -2.2119522f-5 0.00010210608 -5.9326834f-5 -0.00015147644 2.8741024f-5 -9.941839f-5 0.00013675602 -0.00011701144 7.702124f-5 -3.160233f-5 -0.00014670244 3.0696672f-5 5.4862403f-7 0.00015131882 -0.00012398392 1.4395294f-5 1.3728702f-5 9.5729185f-5 -1.4207627f-5 0.00010291316 -9.904216f-6 1.9178982f-5 -4.3012362f-5 0.00017355049 0.0001130138 1.748535f-5; 3.803159f-5 -1.9191843f-5 -2.2461953f-5 -0.00016496018 0.00012105661 1.6808777f-5 4.1296473f-5 -3.0883362f-5 -0.00010366525 0.00011644885 1.7903552f-5 -1.6499764f-5 3.1213272f-5 -0.00012148892 -3.254927f-5 4.529318f-5 5.056621f-5 0.00012815256 0.00018411834 3.127899f-5 0.00014885922 1.1625094f-5 5.800957f-5 3.4804667f-5 -0.000109441164 0.00012662853 -6.208696f-5 -9.7527445f-5 3.033192f-5 -8.0465055f-5 -0.0002725422 -5.7855126f-5; 2.4424811f-5 -0.000114543815 
3.370122f-5 6.957754f-5 2.2636492f-5 0.00013090152 -1.2547328f-5 -0.00022660497 -3.2499465f-5 4.9118657f-6 7.819537f-5 -6.631204f-5 0.00011070473 -2.4906776f-5 -6.207029f-5 -8.8092005f-5 1.3297287f-6 -2.7879247f-5 -3.994083f-5 2.9449611f-5 -9.629269f-6 -3.111421f-5 9.29688f-5 0.00020564764 2.9892986f-5 -6.629923f-5 -7.8544625f-5 -4.97935f-5 0.00026567874 6.584121f-5 -4.055575f-6 0.00017599689; 0.00016610316 -0.00011529264 2.2278978f-6 0.00024999215 0.00019909577 0.00015410359 4.870953f-5 5.915784f-6 -6.474182f-5 7.0797134f-5 4.8109905f-5 0.00016185765 1.5979149f-5 -0.00014346009 1.6206995f-5 -2.5503101f-5 -9.4415256f-5 5.5323275f-5 6.3475265f-5 2.6959742f-5 -3.1152977f-5 -4.660535f-5 7.088729f-5 -5.0975417f-5 -1.1338998f-5 3.8688242f-5 -0.00014194459 -4.933175f-5 6.416656f-5 7.541669f-5 4.987582f-5 -0.00010965505; -0.00022470068 -7.9923484f-5 -0.00015109428 3.4076907f-5 3.2358314f-6 -4.6431655f-6 6.703966f-5 -4.2079926f-5 2.90432f-7 4.438677f-5 -0.00013550374 -4.718559f-5 0.00016251358 -8.2043625f-5 -2.6227055f-5 8.7266395f-5 -8.153693f-5 0.00016258996 0.00010053786 3.8081045f-5 7.910016f-5 0.00014494944 -1.6564074f-5 0.00020702808 -7.925179f-5 -2.267569f-5 0.00013224421 -2.2564716f-5 1.9990728f-5 8.063634f-5 -0.00012483144 6.849232f-5; -3.3169406f-5 8.418706f-6 -3.3823753f-6 -7.374794f-5 -5.338565f-5 0.00018493611 -7.360977f-5 5.3407484f-6 0.00015228402 0.0002028485 1.4077922f-5 -6.55989f-5 -1.311591f-5 9.5784286f-5 7.667825f-5 -2.0362151f-5 -7.3929165f-5 0.0001235654 -2.6209054f-5 0.0001635526 0.0001312094 0.00014993116 -0.0001466645 6.742799f-5 0.00017454114 4.6003475f-5 -0.00014373982 -5.839347f-5 -0.00012072421 0.00015888044 0.00013382165 -1.8607447f-5; -0.000107355925 -8.607889f-5 0.00012149916 9.585516f-5 5.562724f-5 -0.00019318206 3.7993465f-5 -2.2566033f-5 3.0535022f-5 2.7782164f-5 8.723057f-5 9.7498436f-5 -2.7319773f-5 6.182337f-5 9.96632f-5 3.7555987f-5 3.4654637f-5 1.7276001f-5 -8.269291f-5 -6.42737f-5 -3.114173f-5 3.6994059f-6 0.00015694385 
-0.00015989404 -0.00015704743 0.00011026388 -9.139098f-5 -0.00017994767 9.747675f-6 -4.1182953f-5 -5.9712216f-5 9.566006f-5; 1.6924797f-7 0.00025378575 -0.00013553043 -0.00013438355 -7.333486f-5 -5.6306497f-5 0.00010499927 -0.0002021184 3.6227834f-6 -4.519982f-5 0.0001146496 0.000100452824 -6.980514f-5 -0.00013293364 -0.00023491589 0.00013674464 -6.2428786f-5 0.0001285128 -6.296377f-5 2.0308371f-5 -3.1947144f-5 5.1006864f-5 -0.00011716416 4.032473f-6 0.00013381607 9.3557544f-5 5.6698143f-5 -0.00010606786 -7.167248f-5 -7.4547585f-5 -0.00010068391 0.000112814; 0.00012063817 -8.751235f-5 9.72957f-5 -1.5679261f-5 1.4381618f-5 0.00013394516 3.6540045f-5 4.135253f-5 3.8867776f-5 5.7616584f-5 -0.000108938795 -0.00010166722 -0.00011184273 0.00014328359 0.000109988425 -3.1824453f-5 -6.589847f-5 -5.4398224f-5 -0.00017567976 4.2107396f-5 0.00012888073 -7.421487f-5 -4.6915324f-5 2.8323973f-6 0.000105741165 4.0331117f-5 -4.730636f-5 3.6048892f-5 -8.619665f-5 0.00010764717 -9.3132156f-5 0.00010064516; 3.8945174f-5 -2.2936833f-6 0.00014723495 5.68572f-5 -0.00016622305 -1.673823f-5 2.1661623f-5 -0.00015275583 -6.224986f-5 2.5103704f-5 -1.0769967f-5 -5.21992f-5 1.2495653f-6 -0.000101802936 -0.000165851 -0.00012791614 -2.2503202f-5 -4.4841063f-6 -0.00010401725 -9.778565f-5 -6.8153066f-5 2.7580072f-5 -5.7685616f-6 6.225862f-5 -0.00019034327 -4.3543852f-5 -2.1784434f-5 0.00014186854 -6.113162f-5 -5.8805905f-5 -0.00012999796 0.00015724282; -8.00086f-5 3.544768f-5 -0.00011183314 -0.00016493548 6.749685f-5 5.498128f-5 4.1381318f-6 -1.30272f-5 -1.691521f-5 -3.4878452f-5 1.8443372f-5 -0.00012870962 3.1273383f-5 0.00010086395 5.5960852f-5 -6.011614f-5 -7.091628f-5 0.000100803365 0.0002269973 2.8328902f-6 4.8624086f-5 0.00017089135 -4.348911f-6 -8.013109f-5 5.5180717f-6 -3.200978f-5 5.4567146f-5 7.815284f-5 -0.00015617411 -0.0001061558 -0.00019061778 -6.8618894f-5; 0.000101626625 -0.00016719713 8.569818f-5 -5.558202f-5 -2.15394f-5 4.8892354f-5 -2.5873784f-5 2.8590657f-5 2.0070727f-5 
-3.3890276f-6 1.4975326f-5 -3.7693782f-5 -7.6235174f-6 0.00017076069 0.00019831296 -7.3308074f-5 6.064777f-5 -4.6035486f-5 -7.370111f-5 0.0001457052 7.702005f-5 8.861587f-5 -3.323366f-5 -1.6152033f-5 -3.279736f-5 5.6920882f-5 4.7634556f-5 0.00020827906 -4.148211f-6 -6.0836195f-5 6.39705f-5 -9.1393114f-5; -0.000111166 0.00010513608 0.00011285285 -3.040426f-8 4.7227088f-5 4.140097f-5 -5.769202f-5 -0.00011897626 7.844662f-5 -5.611753f-5 -1.5753123f-5 -3.26639f-5 0.00014873383 -0.00014093294 0.00016771043 -9.930882f-5 1.6234171f-5 -0.0001781381 -0.00011105378 -0.00014650721 -8.963978f-5 1.6580145f-5 -0.0001353262 -6.581112f-5 -6.2743144f-5 5.2253294f-5 1.7963357f-5 -0.00011278885 0.000110292814 -0.00011297213 0.00011059856 0.00018377157; 0.00013341317 -9.3397866f-5 7.427955f-5 -0.00021352527 -0.00013579272 -2.7562688f-5 7.092179f-5 -1.7185943f-5 -3.387787f-6 -5.9925027f-5 6.105416f-5 4.5035493f-5 9.493584f-5 7.125668f-5 -0.00012371443 -0.00010645253 2.3866223f-5 -3.263969f-5 -8.1273756f-5 -6.51258f-5 -2.8719069f-5 7.114998f-5 3.424658f-5 -9.669043f-5 -6.444645f-5 -7.716474f-5 8.997803f-6 -3.755644f-7 2.1168764f-5 -3.60344f-5 -8.678021f-6 -4.50861f-5; -5.9666116f-5 -0.00018533276 7.8407014f-5 -8.0834696f-5 -5.8524336f-5 -0.000110462686 0.00010860594 8.376145f-5 -8.785732f-5 -0.00020260327 0.0003467606 -1.9261335f-5 -0.00013966025 -3.3708333f-5 2.1366748f-5 -4.2759002f-5 -6.865495f-5 -3.467193f-5 5.541349f-5 6.7042265f-7 1.5427833f-5 -7.091958f-5 -0.000162441 -7.449017f-5 9.175909f-5 5.0602353f-5 -4.6617057f-5 3.2259133f-5 3.7648988f-5 5.55581f-6 -2.2591812f-5 4.4970322f-5; -0.00020058607 -6.449334f-6 -3.67323f-5 -0.00014165934 0.00013014472 0.00018787287 -1.7228871f-5 0.00013599594 -0.00016266441 -4.1922794f-6 -0.00015345607 0.000107351436 2.0680678f-5 6.125564f-5 2.6016774f-5 -0.00013112843 -0.00010157658 -0.00011807096 6.0145354f-5 8.215798f-6 9.0576075f-5 -7.899245f-5 -4.5886834f-5 4.9050956f-5 1.5268419f-5 0.00010661592 2.949581f-6 8.102905f-6 -4.4800752f-5 
-7.791474f-5 0.0001469789 -9.452163f-5; 4.823068f-5 -0.00014505514 0.00012577002 -5.339856f-5 -0.00013106868 -0.00015173006 -3.0740277f-5 9.996116f-5 -4.2713902f-5 1.9340714f-5 -5.7415796f-6 -7.3743184f-5 -0.00020081917 6.962074f-5 2.9132001f-5 -2.4119234f-5 7.943285f-5 -0.00014018545 -8.893443f-5 0.000121505 -4.4989676f-5 6.5672174f-7 0.00010769937 -3.62389f-6 0.00011422599 -3.927692f-5 6.226391f-5 0.00011233129 6.952302f-5 -0.00013955105 0.00011461146 -3.44124f-5; 0.00022361292 3.596944f-5 -0.00012398708 -1.5435455f-5 -8.898921f-5 0.000103168764 2.0451074f-5 -0.00012059937 -3.8259324f-5 -0.00019358117 -4.979572f-5 -0.00015401983 5.910952f-5 -8.4161475f-5 1.634053f-5 5.2971292f-5 4.116667f-5 3.7763297f-5 -5.4077394f-5 -3.527405f-5 0.00022457157 -8.9463116f-5 -5.692103f-5 -0.00011013869 -8.722587f-5 -7.0986556f-5 -5.6270488f-5 -1.648109f-5 -0.0002132202 0.000115272844 7.8264595f-5 7.292822f-5; 0.00011063703 -7.561902f-5 -0.000103971775 0.00018668755 2.853057f-5 -0.00021859571 4.426497f-6 0.000106792846 -0.000126623 -3.690516f-5 1.9636636f-5 -0.0002242458 -5.93306f-5 -3.680187f-5 -0.00019024026 3.97743f-5 0.00012052482 0.00017755333 1.4125967f-5 5.3715925f-5 7.021795f-5 -0.000113914 4.950059f-5 -4.8192437f-6 -6.822782f-5 -4.310747f-5 0.00010424881 5.900552f-7 7.651495f-5 9.624917f-6 0.00020455051 9.9332916f-5; 8.750619f-5 -8.58246f-5 0.00010468006 -6.4456784f-5 -1.6257334f-5 -0.00021459146 0.00014282855 -4.1796134f-6 0.00011234595 9.1183545f-5 3.8420236f-5 1.4994954f-5 -0.00012757046 1.5510936f-5 0.0001231512 3.3575895f-6 -9.578857f-5 -1.959486f-5 4.598527f-5 -0.000121347206 8.450819f-5 -0.00011940116 -0.00017438835 7.763401f-6 4.9478414f-5 -0.00012758377 -0.000103464896 8.383257f-5 -6.8946116f-5 3.8058282f-5 0.00011218158 -0.000118040385; 0.00013196372 -8.101151f-6 0.00014572423 1.5219099f-5 -1.3444834f-5 8.4206215f-5 -6.9514f-5 -0.00031685378 1.7431212f-5 -7.659907f-5 0.000113385446 -4.76731f-5 -0.000118806805 -0.00011994375 -5.205168f-6 0.0002806604 3.517635f-5 
-6.0661005f-5 -1.0633f-5 0.0001435789 -2.5149824f-5 -0.00011427257 5.9354283f-5 8.768726f-6 -6.4465064f-5 -0.00017247748 -5.3954474f-5 -0.00010216108 -0.00021148645 -6.746871f-5 9.383012f-6 1.9787758f-5; -7.127061f-5 1.2896828f-5 0.00020049346 -6.43558f-5 -9.401603f-5 -7.576292f-5 -8.811951f-6 0.00012456342 4.793344f-5 -4.0754425f-5 7.790805f-5 -8.136853f-5 -7.134165f-6 0.00012115491 1.54994f-5 9.83014f-6 5.3736603f-5 0.00015633376 -0.00018541265 -0.00018361832 0.00016610275 5.0876803f-5 -0.00017668874 8.6748616f-5 -6.0961982f-5 -0.00018607105 5.4098706f-5 0.000117590374 -8.1158694f-5 3.2673073f-5 -0.00016728642 0.00011497731], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.00014334514 -1.8708533f-5 3.3741595f-5 1.9931455f-5 -8.598348f-5 -6.8554077f-6 -0.000103900195 0.00012540958 -0.00017757494 -0.00012890721 3.8913797f-5 1.479463f-5 8.791441f-5 1.7800096f-5 -6.3240805f-5 0.00023060251 -2.8496072f-5 4.233342f-6 -1.234489f-6 -4.243917f-5 6.3923826f-6 -9.946762f-5 -6.0673556f-5 3.8248538f-5 2.0238642f-5 5.6737947f-5 -5.7877027f-5 -0.00014475008 -0.00012688986 0.00011227103 0.000115555835 -1.4685401f-5; 0.00012974914 0.00024853394 0.0001970314 -0.00021083282 -0.00011195952 4.3053924f-6 -9.572058f-5 -0.00011776511 -7.443017f-5 -0.00013034196 -1.11894515f-5 4.917172f-5 0.00019243901 8.5792286f-5 -8.330451f-5 6.223397f-6 2.271748f-5 -1.2525237f-5 -1.2211011f-7 4.2285123f-5 -6.3616467f-6 -6.111979f-5 9.237578f-5 0.000108734195 -4.983907f-5 -3.340851f-5 5.9200913f-5 0.00010747817 2.9799761f-5 5.5263736f-5 -2.893236f-5 5.7339206f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64
# Convert the freshly initialised parameters with `f64` and wrap them in a ComponentArray.
const params = ComponentArray(f64(ps))
# Bundle the network with its state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Output:
# StatefulLuxLayer{Val{true}()}(
#     Chain(
#         layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#         layer_2 = Dense(1 => 32, cos), # 64 parameters
#         layer_3 = Dense(32 => 32, cos), # 1_056 parameters
#         layer_4 = Dense(32 => 2), # 66 parameters
#     ),
# ) # Total: 1_186 parameters,
#   # plus 0 states.

# Now we define a system of odes which describes motion of point like particle with Newtonian physics, uses
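$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{F}_1(\cos\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{F}_2(\cos\chi)\bigr),$$

where $p$, $M$, and $e$ are constants and $\mathcal{F}_1$, $\mathcal{F}_2$ are the two outputs of the neural network.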
"""
    ODE_model(u, nn_params, t)

Out-of-place right-hand side of the neural ODE.

Both rates share the factor `(1 + e * cos(χ))^2 / (M * p^(3 / 2))`, with `p`, `M` and `e`
read from the global `ode_model_params`. That factor is multiplied by
`1 + nn_model([χ], nn_params)`, whose first entry scales `χ̇` and whose second entry
scales `ϕ̇`.

# Returns
The two-element vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, `st` should be used to store the state of the neural network.
    y = 1 .+ nn_model([first(u)], nn_params)
    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))
    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]
    return [χ̇, ϕ̇]
end

# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Integrate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings as the true model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end

# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
# Mean-squared-error loss object used by `loss`.
const mseloss = MSELoss()

"""
    loss(θ)

Return the mean squared error between the waveform predicted with parameters `θ` and the
reference `waveform`.

The neural ODE `prob_nn` is re-solved with `p = θ` (fixed-step RK4, saved at `tsteps`) and
the first element of `compute_waveform` is compared with the global `waveform`.
"""
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end

# Warmup the loss function
# Evaluate the loss once with the untrained parameters.
loss(params)
# Output: 0.0007113733486341368

# Now let us define a callback function to store the loss over time
# History of loss values, appended to by `callback`.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses` and print a progress line.

`θ` is the state object handed to the callback; only its `iter` field is read, and it is
printed as the iteration number.

# Returns
`false`, so the optimisation is never stopped by this callback.
"""
function callback(θ, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
    return false
end

# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.
# Gradients of the loss are obtained through the Zygote AD backend.
adtype = Optimization.AutoZygote()
# The objective is written as `(x, p)`; `loss` only needs `x`, so `p` is ignored.
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
# Start the optimisation from the untrained parameters.
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a backtracking line search, capped at 1000 iterations; `callback` is
# passed through to record and print the loss.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-8.891526522345888e-5; -9.626581595516803e-5; -0.00012949806114197242; -8.801642979951853e-5; -2.0153756850024268e-5; -0.00021565309725675487; -2.2652955522077554e-5; -0.0001382736227240381; -1.7835845937960477e-5; 9.843287989476622e-5; 2.455566118443988e-6; 5.831199086967757e-5; -3.602832293834349e-5; 4.5859167585111264e-5; 3.144184302072257e-5; 2.440700154692941e-5; -0.00010484952508704368; 0.0001774053816914699; -9.148711251316011e-5; -0.00011924670980057097; -0.0001232191862072678; 8.601510489821682e-5; -8.865315612632716e-5; -7.072168955337592e-5; 8.544935553790337e-5; 7.637073576903671e-5; 5.566050094785027e-5; -7.163390546335818e-5; 9.260592923949649e-5; -2.24089544644596e-5; 2.595823934825068e-5; 5.581679579335688e-5;;], bias = [-2.0350085812319537e-17, 5.2649276742685325e-17, 4.532061310860112e-17, -1.0035249751730721e-16, -2.801117535725461e-17, -2.8645909639868223e-16, 1.676870628273561e-17, -2.1243297055191444e-16, 9.219588414149747e-18, 2.275407824423685e-16, -1.9495592607412648e-18, 2.2017574157725373e-17, -2.253406983471024e-17, 4.2795500235329414e-17, 7.259822677633862e-18, 4.4038070264898445e-18, -4.022497044910005e-18, 1.9124299692041483e-16, -1.310467867018903e-16, -2.4304637112209906e-16, -1.298223557187123e-17, 3.834136276975492e-17, -6.238732257898777e-17, -9.913421882923325e-17, 1.1039515206383633e-16, 1.3723702103644183e-16, 2.7856102857409234e-18, -3.180115290520946e-17, 1.7223880858618611e-16, -2.3458817734725695e-17, 5.654986417973704e-18, 5.0548996270366936e-18]), layer_3 = (weight = [-6.0562116905959326e-5 5.345793784651329e-5 0.00016091388557394292 2.4134839131118753e-5 8.114968682927548e-5 8.032585925244605e-5 -5.306285287663804e-5 0.00026293781434939476 -1.5516517797396474e-5 8.237247698995634e-5 -0.00011370301193967919 8.210855617941788e-5 -3.1118328243638476e-5 -0.00020261843723422404 -0.00010638505022218919 -0.0001075766920083249 7.537192892338482e-5 
-5.5334432309641734e-5 -1.8100394064123803e-5 3.657717048478605e-5 -0.0002049306492525218 5.3099062151130347e-5 0.0002722709471155833 -4.436295033959262e-5 9.51052791234982e-5 -5.0330430144822616e-8 8.521565201514835e-5 -0.00014991927286321823 -0.00011430174321579383 0.00017970716022243258 -0.00027880298124803696 3.889589448514828e-5; 6.033357210165811e-5 -3.987062512257126e-5 -9.908448549046055e-5 -2.2236572412080258e-5 7.330789964820132e-5 -2.8404016055863942e-5 -0.00011742224881251467 -0.00011374268063169739 2.7304480746382487e-5 8.366267145087029e-5 0.00010053074226655864 -0.00015309399162012237 7.55377405503474e-5 -4.92885264329436e-5 -0.00012524639890122347 0.00010918307823070201 8.392828028357624e-5 4.345973145158387e-5 -5.4302254219660564e-5 -0.00015768325822952114 7.527643180859377e-5 1.6847001526034256e-5 0.00011905784055764776 -7.86974559823336e-5 -0.00015681830694026336 -8.461355167034124e-5 0.00018905609619744966 1.2293599021208402e-5 0.00010818876041506418 -9.326318097825287e-5 7.746867056178738e-5 -1.0636906998994215e-5; -1.769118974615373e-5 5.925208021373128e-5 -1.531749042904145e-5 0.00013281040370649163 -0.0001080045047593513 -6.614675006407997e-5 5.6620075306368587e-5 1.1429054361473112e-5 8.128773257786161e-5 -8.192865137524292e-5 3.5422860641942405e-5 5.238129521322472e-5 -0.00015661896835362696 3.983113052973441e-5 -0.00014499620814157835 -8.551694266779568e-5 -9.758727613651395e-5 2.804273132363605e-5 -9.84550213934991e-5 -9.496448259149585e-5 1.771852036192262e-5 0.00010323984967263816 -9.538610250737074e-5 -3.251292658668761e-5 0.00016126152255757176 -7.731747685746963e-5 2.2365752706634892e-5 -0.00010275548341732277 -0.000186490325027159 -3.427833220038791e-5 6.749842501895288e-5 -0.00012027766282252958; 8.186268130615857e-5 -5.803483439155393e-6 -5.7445695368143745e-5 -4.008475845963202e-5 0.00010130745975133093 -2.312200196950722e-5 0.00013606652098773113 -2.8463018700631838e-5 2.901554032818005e-5 -0.00019941714528882266 
-0.00010973813734650574 -1.6375644939180886e-5 4.4640770262993145e-5 4.2205532715378076e-5 1.585613060982528e-5 -5.196638996764014e-5 0.00012224858981157408 1.8399050511159235e-5 -0.0002358506590307637 -0.0001060105987771327 -0.0001297398248634208 2.3392560497468276e-6 5.131675968827247e-5 -0.00016858499841233382 -1.0946241693962358e-5 -6.33379586569927e-5 -8.896912733958816e-5 1.2108142961588351e-5 0.00011785966669507866 4.344778606267713e-5 -2.7915390112832344e-5 6.930596462283503e-5; 2.803458387697741e-5 7.791966795830154e-5 0.00016608292658256118 -5.246472851231527e-5 5.8140442845435306e-5 -8.078956584306201e-5 -6.588668752319609e-5 -4.997749333361525e-5 -0.00015308933784344652 8.182448523493561e-5 -0.00010169724105016732 0.0001686381119294315 5.443271307067481e-5 -0.00011236435516129752 -0.00010033005951063227 8.661307670318087e-5 1.3303618245660488e-5 0.00021071635724418796 0.0001753731311988286 3.813384769509413e-5 4.800393521927119e-5 -4.135959447685445e-5 9.168577248837962e-5 -0.00016690363104080485 -3.3045968891566607e-5 0.00010841730995944025 -0.00010967166876744095 0.00017800162909461473 7.996013751162345e-5 -0.00014051411112395542 -0.00010580253278694553 3.1435306439198655e-5; -0.0001872182828403482 -5.252371094249408e-5 -5.854742895754654e-5 7.802122225692164e-5 -7.00170266238501e-5 -1.2911653280014142e-5 5.597732987077975e-5 9.500299446225672e-5 8.402826741510123e-5 -5.146015987718269e-6 8.182760310298793e-5 1.1237197290736261e-5 6.948032485959247e-5 -1.1379018283354129e-5 2.4266909426638025e-5 -8.255587355527023e-5 5.390226383300268e-6 0.00017557947383902515 -9.51176996577494e-8 -5.515004787886371e-5 8.432897546734875e-5 -0.00015324695315504803 -4.2612190477683645e-5 -9.100905200759483e-5 -0.0001445601873594745 -3.618751990436117e-5 -0.00020300876978146583 -9.312044759573938e-5 -8.267975400960344e-5 -2.7634377070375978e-5 0.00010075038252105083 -5.778450842192137e-5; -6.468028096464793e-5 0.00023088313809380563 -4.416628599614987e-5 
0.00021037719033443035 -6.691018007445634e-5 9.833393952757988e-6 0.00013173060778312828 7.736421593249672e-5 2.073763334151336e-5 -8.95987841508474e-5 -6.238878359768102e-5 3.995454516511918e-5 -2.325322294241284e-5 -7.780825490813221e-5 8.184553643472951e-5 2.3255365696558573e-5 -7.450283101975255e-5 -0.00014840388849592613 -3.334344456420602e-5 0.0001696721365001871 8.510347014776248e-5 8.46632337156807e-6 -8.485472881916311e-5 -8.224976137030332e-5 -0.0001727679273321421 2.841702545584163e-6 -4.80137487870722e-6 -7.964206543733755e-5 -7.678160273685439e-5 -0.00011036128945165798 -3.629644653160604e-5 7.12566825236989e-5; -4.2323813580289064e-5 -0.00020054265875237462 5.3833878006714125e-5 -0.00012453812192993768 -4.5567162636288035e-5 -2.2806515665538178e-5 9.833293767443426e-6 7.732850539165073e-5 2.5111738829437975e-5 1.2141755939092113e-5 7.0000767997415396e-6 -7.51911513028571e-5 2.087981814533311e-5 7.774995796047993e-5 0.00024123991064392775 5.547581425766238e-5 -8.251555224282916e-5 -0.00010259202932805154 -7.146287058581144e-5 9.054478902411952e-6 0.0001361166422035743 4.458195584142949e-5 -4.021563396540992e-5 -1.2663910786936315e-5 -1.8461392195169491e-6 -6.0883209847228136e-5 1.197348668633257e-5 -2.1667337171815476e-5 -0.0002291093780180299 0.00012493220023010627 -0.0002026263183901015 -0.0002665873991312366; -5.548871972262812e-5 0.0001280445244033827 7.841825886732103e-5 5.873198047499218e-5 -4.357266375731254e-6 -0.00023105643462752485 1.9082609882903953e-5 7.275244156932619e-5 -8.262231544489069e-5 -0.0001662360341883258 0.0001644232938402329 -0.00011853837763322514 -7.963202783269826e-5 -2.300517603128693e-5 -5.1466348436800976e-5 0.00015801383001832266 -4.250135366003882e-5 3.959302871823595e-5 2.7342180560087376e-5 -0.00012313753954497987 -0.0001583054295637633 -1.1559942818560676e-5 8.468010806115393e-5 5.942241792087462e-5 4.859181668013376e-5 -7.53757454235104e-5 -4.5701894619409e-5 8.946955362777426e-6 5.1005597428745205e-5 
0.00016106823336901115 0.00011125809710808073 -8.643431333024054e-5; -9.011581315873667e-7 -9.175420754595228e-5 -5.148979284200648e-6 0.0001209329547290099 -3.0331850297566933e-5 3.111087977055125e-6 9.652165127214549e-5 0.00026940224589074274 -0.0001814430816867895 8.391657129855279e-5 5.552942630480606e-5 -1.7395388653439593e-5 4.482214890764509e-5 0.0001596835262770238 -0.00013327694077282693 2.4784503608192605e-5 -2.5038227335379894e-5 7.334774809880868e-5 9.600986041355481e-5 0.0001130743675329057 -3.506279071566892e-5 -4.322133290238031e-5 -1.979553425748704e-5 9.199075239472082e-5 0.00010408774497214431 0.00020631989710260076 3.804288090201302e-5 -2.4896409826525736e-5 3.766946056735283e-5 3.283872675261965e-5 -0.00012203937703283017 2.7496307580677988e-5; 2.885995242048455e-5 5.647454169173654e-5 -0.0001563028570095553 0.00013754489851943152 -0.00015074625183534842 -0.000142039753642727 -2.211939930622502e-5 0.00010210620463420528 -5.932671114043368e-5 -0.00015147632142235596 2.874114694658167e-5 -9.941826350460807e-5 0.00013675614105830865 -0.00011701131801293273 7.702136578098832e-5 -3.160220579649358e-5 -0.00014670231824913139 3.0696795205025294e-5 5.487470510932866e-7 0.00015131894111926577 -0.00012398379543502908 1.4395416895590552e-5 1.3728824581798909e-5 9.57293079985766e-5 -1.4207503897173497e-5 0.00010291328295660148 -9.904093291529452e-6 1.9179105277894452e-5 -4.30122388357769e-5 0.00017355061505748855 0.00011301392011172102 1.7485473348083443e-5; 3.803174518422134e-5 -1.9191687847828537e-5 -2.2461797875926316e-5 -0.00016496002167089928 0.00012105676619709416 1.6808931881917605e-5 4.129662743677332e-5 -3.0883207678321354e-5 -0.00010366509190619976 0.00011644900410343792 1.790370702534863e-5 -1.649960899643202e-5 3.121342711407634e-5 -0.00012148876516706373 -3.254911457164268e-5 4.529333642009592e-5 5.05663647842676e-5 0.00012815271661990117 0.00018411849492137232 3.127914356324701e-5 0.00014885937894602182 1.1625248639270717e-5 
5.800972399325952e-5 3.4804821602543074e-5 -0.00010944100899481304 0.00012662867998612036 -6.208680674293147e-5 -9.752729051373047e-5 3.0332073997365303e-5 -8.046490069803836e-5 -0.00027254205446782583 -5.7854971551327696e-5; 2.442636142754417e-5 -0.00011454226440596972 3.370277082816675e-5 6.957908809803857e-5 2.263804200433526e-5 0.00013090307076147175 -1.2545777814149856e-5 -0.00022660342036593845 -3.2497915150226464e-5 4.9134158997212625e-6 7.819692147117925e-5 -6.631048707836593e-5 0.0001107062784411461 -2.4905226006151507e-5 -6.206874203249044e-5 -8.8090455015983e-5 1.3312788167625383e-6 -2.7877696617150833e-5 -3.993928074683597e-5 2.9451161402758325e-5 -9.627718633200878e-6 -3.111266013198087e-5 9.297035322910598e-5 0.00020564918705881207 2.9894536431897287e-5 -6.629768139296617e-5 -7.85430744018898e-5 -4.979194962848349e-5 0.0002656802902618835 6.584276017187956e-5 -4.054024963035794e-6 0.0001759984374201364; 0.00016610580331871303 -0.00011528999476468291 2.2305445690925895e-6 0.0002499948006425242 0.00019909841484150575 0.00015410623581370774 4.871217805885427e-5 5.91843098844738e-6 -6.473917329457855e-5 7.079978076035021e-5 4.8112551839954615e-5 0.00016186029660275608 1.5981795686461257e-5 -0.00014345743858029553 1.620964229914968e-5 -2.5500454112256987e-5 -9.441260936910536e-5 5.5325921666713906e-5 6.347791185362241e-5 2.696238897441827e-5 -3.115032978713189e-5 -4.660270297790819e-5 7.08899371511491e-5 -5.0972770450242944e-5 -1.1336351062710655e-5 3.8690888859491866e-5 -0.00014194194392075338 -4.932910252136857e-5 6.416920786250339e-5 7.541933718130641e-5 4.987846677068324e-5 -0.00010965240090626798; -0.00022469957958901046 -7.992238580709113e-5 -0.00015109318384429497 3.4078005669929866e-5 3.236929601832526e-6 -4.642067337509403e-6 6.704076030064896e-5 -4.207882785931369e-5 2.9153021433306527e-7 4.4387866624223156e-5 -0.0001355026402184127 -4.7184491008460074e-5 0.00016251467558095728 -8.204252709628551e-5 -2.6225956857609186e-5 8.72674931865796e-5 
-8.153582940803692e-5 0.00016259105858398027 0.00010053895629016924 3.808214340597425e-5 7.910125667467715e-5 0.00014495054300414672 -1.6562975828943043e-5 0.00020702917343942812 -7.925069122802186e-5 -2.2674591049958427e-5 0.0001322453094740069 -2.2563618144426657e-5 1.9991826307436004e-5 8.063743688172251e-5 -0.00012483033834950484 6.849341979023496e-5; -3.316559922310351e-5 8.422513030415937e-6 -3.378568169839905e-6 -7.37441350683737e-5 -5.338184381612421e-5 0.0001849399213933678 -7.360596463327155e-5 5.344555585928899e-6 0.00015228782468698152 0.00020285231040546625 1.4081729215876843e-5 -6.55950916443134e-5 -1.311210324997967e-5 9.578809296003492e-5 7.668205399754543e-5 -2.035834403427651e-5 -7.392535734467412e-5 0.0001235692125301352 -2.6205247184372272e-5 0.00016355640583427025 0.00013121320263237912 0.00014993496916944392 -0.00014666069727169173 6.74317943885789e-5 0.00017454494992492965 4.600728244626389e-5 -0.00014373600969334305 -5.838966352526519e-5 -0.00012072040225390255 0.0001588842515157457 0.00013382546059078658 -1.8603639535244442e-5; -0.00010735638641096438 -8.607935449151088e-5 0.0001214986958303896 9.58546996345206e-5 5.56267776347614e-5 -0.00019318252013805464 3.799300322964892e-5 -2.2566494567278668e-5 3.053456096052294e-5 2.7781702397198046e-5 8.723010689675907e-5 9.74979746616833e-5 -2.732023419855045e-5 6.182290674144532e-5 9.966273939892762e-5 3.7555525364159744e-5 3.4654175419865934e-5 1.7275539744455446e-5 -8.269337120042892e-5 -6.427416080884325e-5 -3.114219286987233e-5 3.6989445790197794e-6 0.00015694338510468817 -0.00015989450473613862 -0.00015704788818299563 0.00011026341674108242 -9.139144383413572e-5 -0.00017994813531447025 9.747213721431256e-6 -4.1183414555116447e-5 -5.971267744359452e-5 9.565960210705321e-5; 1.6774149010678917e-7 0.00025378424435026193 -0.00013553193311103073 -0.00013438505301614185 -7.333636348498304e-5 -5.630800332303031e-5 0.00010499776509227917 -0.00020211991103727644 3.6212769459360363e-6 
-4.5201326784573356e-5 0.00011464809323065643 0.00010045131749343952 -6.980664365123234e-5 -0.00013293514383853582 -0.00023491739592877437 0.00013674313196084223 -6.24302923738837e-5 0.0001285112881391528 -6.296527898534065e-5 2.030686447093218e-5 -3.194865020353811e-5 5.100535746301613e-5 -0.00011716566986300872 4.0309665650568924e-6 0.00013381455902113665 9.355603798163137e-5 5.6696636974679966e-5 -0.000106069368840093 -7.167398906907853e-5 -7.454909099628463e-5 -0.00010068541486411594 0.00011281249281587892; 0.00012063914495111656 -8.751137771586718e-5 9.7296672391807e-5 -1.567828745582389e-5 1.4382591542661953e-5 0.00013394612927976928 3.654101841558217e-5 4.135350412862237e-5 3.88687500514694e-5 5.761755768499494e-5 -0.00010893782142492483 -0.00010166624393722618 -0.00011184175797215244 0.0001432845589431136 0.00010998939855081458 -3.1823479149616546e-5 -6.589749806270818e-5 -5.4397250424973355e-5 -0.00017567878363597007 4.2108370179180194e-5 0.00012888169898242675 -7.421389578784588e-5 -4.69143500243199e-5 2.8333710258338985e-6 0.0001057421383290668 4.033209064683222e-5 -4.730538545235627e-5 3.604986581077168e-5 -8.619567357645177e-5 0.0001076481409097228 -9.313118185790746e-5 0.00010064613037567016; 3.8941470256480775e-5 -2.2973873913582097e-6 0.0001472312500024651 5.685349547065349e-5 -0.00016622675379977954 -1.6741933586904657e-5 2.165791861414994e-5 -0.00015275953833251271 -6.225356687565739e-5 2.5099999538448383e-5 -1.0773671101848352e-5 -5.220290463540737e-5 1.2458612897582225e-6 -0.00010180664016871904 -0.00016585470498313518 -0.00012791984104325432 -2.250690581631902e-5 -4.487810338355345e-6 -0.00010402095417730879 -9.778935113927018e-5 -6.815676956675498e-5 2.757636807445416e-5 -5.772265604087969e-6 6.22549190208611e-5 -0.00019034697529633098 -4.35475564248089e-5 -2.1788137610485553e-5 0.00014186483836027964 -6.113532494992362e-5 -5.8809608754305264e-5 -0.00013000166713965486 0.00015723911142202958; -8.000959434396032e-5 3.544668350984939e-5 
-0.00011183413818327801 -0.00016493647791760422 6.749585071263216e-5 5.498028320761329e-5 4.137135588436757e-6 -1.3028196294563994e-5 -1.691620610911555e-5 -3.487944860792108e-5 1.8442375844799186e-5 -0.0001287106200059742 3.127238659105605e-5 0.00010086295577382106 5.59598562951326e-5 -6.011713714072162e-5 -7.091727394692878e-5 0.00010080236887477719 0.00022699630852473536 2.831894048701508e-6 4.86230898651932e-5 0.00017089035162387022 -4.349906989614542e-6 -8.013208509039603e-5 5.517075519650932e-6 -3.201077498927226e-5 5.4566150252115246e-5 7.81518471605284e-5 -0.00015617510445265455 -0.0001061567957684273 -0.00019061877475435522 -6.861989033025287e-5; 0.00010162911790838534 -0.0001671946330489967 8.570067153494161e-5 -5.557952603259041e-5 -2.1536906514580132e-5 4.88948470418596e-5 -2.5871290871787675e-5 2.8593149964652067e-5 2.0073220058197456e-5 -3.386534798768269e-6 1.4977818737765863e-5 -3.769128956183358e-5 -7.6210245780260534e-6 0.0001707631789585257 0.00019831545510978134 -7.330558140811783e-5 6.065026251315891e-5 -4.603299306757715e-5 -7.369861408652704e-5 0.00014570769131831384 7.702254230105734e-5 8.861835964206717e-5 -3.323116756210961e-5 -1.6149540114483545e-5 -3.279486840175345e-5 5.692337514200485e-5 4.7637048611151906e-5 0.00020828155253265347 -4.145717993299314e-6 -6.083370239057842e-5 6.397299590470924e-5 -9.139062161840043e-5; -0.00011116763703389734 0.0001051344465384762 0.00011285121804471888 -3.203912191528399e-8 4.72254527711299e-5 4.139933521298236e-5 -5.769365593656066e-5 -0.00011897789731951475 7.844498254493698e-5 -5.611916417469816e-5 -1.5754757943028925e-5 -3.2665533172944034e-5 0.00014873219556542984 -0.00014093457643971798 0.0001677087952419642 -9.931045274348094e-5 1.6232536335200898e-5 -0.0001781397345904572 -0.00011105541266365592 -0.00014650884763900497 -8.964141811806962e-5 1.6578509938742812e-5 -0.000135327839917547 -6.581275151907328e-5 -6.27447785074331e-5 5.2251658825159024e-5 1.796172223188951e-5 -0.0001127904884480391 
0.00011029117963412172 -0.00011297376254438525 0.00011059692264902312 0.00018376993325441181; 0.0001334108648804394 -9.340016803644445e-5 7.427725104908237e-5 -0.00021352757429877258 -0.00013579502664680168 -2.7564989788763804e-5 7.09194911975736e-5 -1.7188244594121824e-5 -3.390088788646883e-6 -5.992732875920651e-5 6.10518593090905e-5 4.503319106124056e-5 9.493353464467008e-5 7.125438169868281e-5 -0.00012371673328044356 -0.00010645482943101349 2.3863921531027185e-5 -3.2641992940884674e-5 -8.127605796499095e-5 -6.512810322628768e-5 -2.8721370465335114e-5 7.114767978027474e-5 3.424427741965325e-5 -9.669273530772123e-5 -6.444875433981414e-5 -7.716704093791651e-5 8.995501236942778e-6 -3.778661537638653e-7 2.116646210897348e-5 -3.603670010516921e-5 -8.680322746213288e-6 -4.508840246881664e-5; -5.9668129619024476e-5 -0.00018533476970675292 7.840500045517125e-5 -8.083670975225965e-5 -5.85263499704236e-5 -0.00011046470019549398 0.00010860392801644756 8.375943587111128e-5 -8.785933307423726e-5 -0.00020260528828999823 0.00034675857175555524 -1.9263349098551917e-5 -0.00013966225953824293 -3.371034683711249e-5 2.1364734199080083e-5 -4.2761016014781636e-5 -6.865696043292767e-5 -3.467394555968544e-5 5.541147554707765e-5 6.684087350018238e-7 1.542581928901403e-5 -7.0921594964387e-5 -0.0001624430208937729 -7.44921838520577e-5 9.175707389316506e-5 5.0600339524167404e-5 -4.6619070711827936e-5 3.22571192339047e-5 3.7646974011511275e-5 5.553795905354282e-6 -2.2593825927986568e-5 4.4968308135352006e-5; -0.00020058706564109587 -6.450326103597228e-6 -3.673329324906224e-5 -0.00014166033638853397 0.00013014372684314792 0.00018787187304602492 -1.722986330956272e-5 0.0001359949500930736 -0.0001626654002883291 -4.193271755732042e-6 -0.00015345706474714047 0.00010735044356456066 2.067968575478971e-5 6.125464728588237e-5 2.6015782119561557e-5 -0.00013112942092798695 -0.00010157757338115663 -0.00011807195101377549 6.0144361619811636e-5 8.214805298984653e-6 9.05750830007945e-5 
-7.89934447801469e-5 -4.588782672167573e-5 4.9049963259809635e-5 1.526742648470381e-5 0.00010661492428530127 2.948588595377826e-6 8.101912450134786e-6 -4.4801744528606336e-5 -7.791572993833624e-5 0.00014697790624827102 -9.452262297950654e-5; 4.823002149410133e-5 -0.0001450557925266811 0.00012576936336759252 -5.339921568461179e-5 -0.00013106933562026 -0.00015173071238617936 -3.074093416157749e-5 9.996050124814561e-5 -4.271455923967638e-5 1.9340056897500503e-5 -5.7422367347963095e-6 -7.374384091176602e-5 -0.0002008198230754267 6.962008541508973e-5 2.9131343981725143e-5 -2.411989094921491e-5 7.943219447102051e-5 -0.00014018611051082522 -8.89350872597325e-5 0.00012150434253330994 -4.4990333262237516e-5 6.560645744236337e-7 0.00010769871500147094 -3.624547215940641e-6 0.00011422533629288472 -3.9277578761675314e-5 6.226325553202487e-5 0.0001123306332744052 6.952236202837499e-5 -0.00013955170521453593 0.00011461080197536752 -3.4413055399773206e-5; 0.0002236108044263108 3.5967320912345576e-5 -0.0001239891939230069 -1.543757247138846e-5 -8.899132848176292e-5 0.00010316664662523461 2.0448956604985334e-5 -0.00012060148622996866 -3.826144205680242e-5 -0.00019358329194521938 -4.9797836651971164e-5 -0.00015402194588835262 5.9107401796671924e-5 -8.416359236155639e-5 1.633841162681035e-5 5.296917476068145e-5 4.1164551988073225e-5 3.7761179158307364e-5 -5.40795121532263e-5 -3.527616757939666e-5 0.00022456945549773023 -8.94652334290427e-5 -5.692314920172631e-5 -0.00011014080558827878 -8.722798560204557e-5 -7.098867393954402e-5 -5.627260583144545e-5 -1.648320761608218e-5 -0.00021322232390678044 0.00011527072598534524 7.826247707550314e-5 7.292610058636473e-5; 0.00011063767830191236 -7.561837468247833e-5 -0.0001039711298625743 0.00018668819917647226 2.853121601501227e-5 -0.00021859506726915767 4.427142596882545e-6 0.00010679349158007473 -0.00012662235006874212 -3.6904513978613364e-5 1.9637281878290117e-5 -0.00022424515394678192 -5.9329953944057143e-5 -3.6801224484320823e-5 
-0.00019023961270877893 3.977494431009725e-5 0.00012052546459298573 0.00017755397653953665 1.4126612552984612e-5 5.371657040367758e-5 7.021859870573409e-5 -0.00011391335486851567 4.9501237083944124e-5 -4.818598187654278e-6 -6.822717317564344e-5 -4.310682401128944e-5 0.00010424945300027885 5.907007340226587e-7 7.651559805080336e-5 9.625562382184216e-6 0.0002045511553324966 9.933356137762001e-5; 8.750501040200448e-5 -8.582577681803348e-5 0.0001046788824402642 -6.445796277542859e-5 -1.6258512527510344e-5 -0.0002145926405995125 0.00014282736849569177 -4.180791688060323e-6 0.00011234477317490071 9.118236669170483e-5 3.841905747836832e-5 1.4993775426466405e-5 -0.00012757163456111718 1.5509757236633342e-5 0.00012315002134133545 3.35641123061641e-6 -9.578974966128495e-5 -1.9596037940870182e-5 4.598409252654522e-5 -0.0001213483843224218 8.450701027474386e-5 -0.00011940233487106204 -0.0001743895259772546 7.762222981602947e-6 4.947723612136859e-5 -0.00012758494956354943 -0.00010346607411914942 8.383139395451744e-5 -6.8947294106998e-5 3.805710405291453e-5 0.00011218040201643883 -0.00011804156345018622; 0.00013196141997700255 -8.103454381355253e-6 0.0001457219292956812 1.521679558440727e-5 -1.3447136841441245e-5 8.420391139459553e-5 -6.951630013414577e-5 -0.00031685608366155464 1.742890861582077e-5 -7.660137300031432e-5 0.00011338314260005065 -4.767540332746399e-5 -0.00011880910853274961 -0.0001199460496695413 -5.207470910553826e-6 0.0002806580993661014 3.5174045190205045e-5 -6.066330781089566e-5 -1.0635303003667291e-5 0.00014357659864122215 -2.51521271979506e-5 -0.00011427486999078064 5.935197958255812e-5 8.766422690809177e-6 -6.446736762650659e-5 -0.00017247978059227106 -5.395677755355564e-5 -0.0001021633837057887 -0.00021148875049898652 -6.74710138968836e-5 9.380709054803937e-6 1.978545486285159e-5; -7.127076507314073e-5 1.2896672234628763e-5 0.00020049330949549642 -6.435595689109493e-5 -9.401618710254323e-5 -7.57630722713727e-5 -8.812106824198299e-6 0.00012456326392711227 
4.793328578815792e-5 -4.0754580294023774e-5 7.790789557619945e-5 -8.13686882960702e-5 -7.134320565828808e-6 0.00012115475427125585 1.5499244074444285e-5 9.829984339002964e-6 5.373644770005767e-5 0.0001563336018822078 -0.0001854128012550153 -0.00018361847734778123 0.00016610259361331421 5.087664711171793e-5 -0.00017668889897177895 8.67484605638219e-5 -6.0962137393662396e-5 -0.00018607120265952293 5.4098550282642345e-5 0.00011759021898032186 -8.115884967847691e-5 3.267291784285565e-5 -0.00016728657184868123 0.00011497715339063677], bias = [3.931374243361431e-10, -4.801485980288223e-10, -2.8165551116252515e-9, -1.831743213903872e-9, 1.4378680155189746e-9, -2.2763881363654307e-9, -6.339254931549423e-10, -2.714990259094748e-9, -3.0972426788868134e-10, 3.5414867531870683e-9, 1.230213322052055e-10, 1.546638174472464e-10, 1.550155364700247e-9, 2.646802391328331e-9, 1.0982016714103551e-9, 3.807155780679483e-9, -4.612728571050646e-10, -1.5064784259602285e-9, 9.737402339972889e-10, -3.704053885068092e-9, -9.961830328968792e-10, 2.4927984621366477e-9, -1.6348632530985608e-9, -2.301742197744436e-9, -2.0139157355407223e-9, -9.923077008940558e-10, -6.571639014823455e-10, -2.1177010532228877e-9, 6.455257743383835e-10, -1.1783179975754105e-9, -2.3031293130857126e-9, -1.5544938860204118e-10]), layer_4 = (weight = [-0.000825435821435167 -0.0007007992132732165 -0.0006483489001528379 -0.0006621591537974912 -0.0007680741191377385 -0.000688945969824155 -0.0007859908702743211 -0.000556680936183584 -0.0008596656237570879 -0.0008109975877878983 -0.0006431768884623387 -0.0006672960547306933 -0.0005941762154194607 -0.0006642904230520978 -0.0007453314611238785 -0.00045148785537058657 -0.0007105867519759002 -0.0006778572895875305 -0.0006833251520187173 -0.0007245295221803269 -0.0006756982794567802 -0.0007815581565435537 -0.000742764175664157 -0.0006438420221205177 -0.0006618519486548406 -0.0006253527161493031 -0.0007399677018588302 -0.0008268406553716065 -0.0008089805383241292 
-0.0005698196227182137 -0.0005665347302628629 -0.0006967760858349021; 0.0003489843111848854 0.00046776910625721905 0.00041626651583320914 8.402328837655297e-6 0.00010727563844197917 0.0002235405254344032 0.00012351459189835173 0.0001014700131101573 0.00014480499939377344 8.88931183356929e-5 0.00020804572097136067 0.00026840689064702554 0.00041167416342430165 0.00030502740507573313 0.00013593065089145988 0.00022545846800012522 0.00024195265144940454 0.0002067099180986097 0.00021911305523425107 0.00026152018872605215 0.00021287351834953913 0.0001581153337360021 0.00031161093498153857 0.0003279693275583185 0.00016939607301015499 0.00018582665634883972 0.00027843608198012187 0.0003267133046513878 0.00024903493047606157 0.00027449889805143684 0.00019030277419365541 0.0002765743786196463], bias = [-0.0006820906855595285, 0.0002192351725549596]))Visualizing the Results
Let us now plot the loss over time
# Plot the recorded loss values against their position in `losses`.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)
    fig
end
Finally let us visualize the results
# Rebuild the ODE problem with the optimised parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
# Same fixed-step RK4 integration as before, now with the optimised parameters.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
# Keep only the first value returned by `compute_waveform`.
waveform_nn_trained = first(
compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
# Compare the reference waveform with the network waveform before and after training.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # `waveform_nn` is the prediction computed earlier with the untrained parameters.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # Prediction obtained with the optimised parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        alpha=0.5,
        strokewidth=2,
        markersize=12,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
# GPU backend versions are printed only when MLDataDevices is defined, the backend
# package itself is defined, and MLDataDevices reports the device as functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.12.5
Commit 5fe89b8ddc1 (2026-02-09 16:05 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.