Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning.
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

# Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
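We need a very crude 2-body path. Assume the 1-body motion is the Newtonian relative position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and place the center of mass at the origin, so that $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$ with $M = m_1 + m_2$.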
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of the two bodies
with masses `m₁` and `m₂`, with the center of mass at the origin: body 1 follows
`m₂ / (m₁ + m₂) * path` and body 2 follows `-m₁ / (m₁ + m₂) * path`.

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    # each body is scaled by the *other* body's mass fraction
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in orbital variables into Cartesian `(x, y)` positions.

`soln` has one column per time sample and either
- 2 rows `(χ, ϕ)`: then `model_params = (p, M, e)` supplies constant `p` and `e`
  (`M` is not used here), or
- 4 rows `(χ, ϕ, p, e)`: then `p` and `e` are read per sample and `model_params`
  is ignored.

The radius is `r = p / (1 + e cos χ)`. The result is a `2 × size(soln, 2)` matrix
whose rows are `x = r cos ϕ` and `y = r sin ϕ`.

Throws `ArgumentError` if `soln` does not have 2 or 4 rows, or if `soln` has 2 rows
and `model_params` is missing or does not have exactly 3 entries.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    nrows ∈ (2, 4) ||
        throw(ArgumentError("size(soln,1) must be either 2 or 4, got $nrows"))
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # `nothing` has no length, so check it explicitly to get a clear error
        if model_params === nothing || length(model_params) != 3
            throw(ArgumentError("model_params must have length 3 when size(soln,1) = 2"))
        end
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` (spacing `dt`)
with second-order finite differences:
- central differences at interior samples,
- one-sided three-point stencils at the first and last samples.

Needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)
    # forward three-point stencil at the first sample
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # central difference (v[i+1] - v[i-1]) / 2 at every interior sample
    body = (v[3:n] .- v[1:(n - 2)]) / 2
    # backward three-point stencil at the last sample
    tail = 3 / 2 * v[n] - 2 * v[n - 1] + 1 / 2 * v[n - 2]
    return vcat(head, body, tail) / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` (spacing `dt`)
with second-order finite differences:
- the centered three-point stencil at interior samples,
- one-sided four-point stencils at the first and last samples.

Needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    # one-sided four-point stencil at the first sample
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # centered stencil v[i-1] - 2v[i] + v[i+1] at every interior sample
    body = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    # mirrored one-sided four-point stencil at the last sample
    tail = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    return vcat(head, body, tail) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor of a point mass following `orbit`,
as a vector over time samples.

`orbit` is a matrix whose first two rows are the `x` and `y` coordinates (one column per
sample); the motion lies in the `z = 0` plane. `component` is an index pair `(i, j)`
with `i, j ∈ 1:3`. The tensor is symmetric, so `(i, j)` and `(j, i)` give the same
result. The returned value is `mass * (Iᵢⱼ - δᵢⱼ tr(I) / 3)` with `Iᵢⱼ = xᵢ xⱼ`.

Throws `ArgumentError` for any other `component`. Previously such input silently
returned the `xy` component.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    trace = x .^ 2 .+ y .^ 2  # z = 0, so Izz adds nothing to the trace
    i, j = minmax(component[1], component[2])  # symmetric tensor: order (i ≤ j)
    if (i, j) == (1, 1)
        tmp = x .^ 2 .- trace ./ 3
    elseif (i, j) == (2, 2)
        tmp = y .^ 2 .- trace ./ 3
    elseif (i, j) == (1, 2)
        tmp = x .* y
    elseif (i, j) == (3, 3)
        tmp = .-trace ./ 3
    elseif (i, j) == (1, 3) || (i, j) == (2, 3)
        tmp = zero.(x)  # planar orbit: xz and yz moments vanish
    else
        throw(ArgumentError("component must be an index pair in 1:3, got $(component)"))
    end
    return mass .* tmp
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the finite-difference second time derivative (sample spacing `dt`) of the
`component` entry of the trace-free moment tensor of a body of `mass` on `orbit`.
"""
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole components from
`h_22_quadrupole_components` for a body of `mass` on `orbit` sampled with spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the `(h₊, hₓ)` polarizations of the (2,2) strain mode for a single unit-mass body
following `orbit`, sampled with spacing `dt`.

With `(h11, h12, h22)` from `h_22_quadrupole`, the polarizations are
`h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12`.

The constants are evaluated in the floating-point type obtained by promoting `T` with the
element type of `orbit`. This has two effects:
- an integer `dt` no longer hits an `InexactError` from `T(π)`;
- a `Float32` `dt` no longer lowers the precision of the constants applied to `Float64`
  data.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    R = float(promote_type(T, eltype(orbit)))
    h₊ = h11 - h22
    hₓ = R(2) * h12
    scaling_const = √(R(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the tuple `(h11, h12, h22)` for a two-body system. Each entry is the sum of the
`h_22_quadrupole` contributions of body 1 (`orbit1`, `mass1`) and body 2 (`orbit2`,
`mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # add the (h11, h12, h22) contributions of both bodies entry by entry
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the `(h₊, hₓ)` polarizations of the (2,2) strain mode from the orbits of two
bodies, using `h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12` on the summed
quadrupole components from `h_22_quadrupole_two_body`.

The masses must sum to 1. The tolerance is `1e-12`, widened to a few `eps` for
lower-precision mass types such as `Float32`. Otherwise an `ArgumentError` is thrown.
The constants are evaluated in the floating-point type obtained by promoting `T` with the
orbits' element types, so an integer or `Float32` `dt` works with full precision.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2
    total_mass = mass1 + mass2
    tol = max(1.0e-12, 8 * eps(float(typeof(total_mass))))
    abs(total_mass - 1) < tol || throw(ArgumentError("Masses do not sum to unity"))
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    R = float(promote_type(T, eltype(orbit1), eltype(orbit2)))
    h₊ = h11 - h22
    hₓ = R(2) * h12
    scaling_const = √(R(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` from an orbital ODE solution `soln`, sampled
with spacing `dt`. See `soln2orbit` for the accepted `soln` layouts and `model_params`.

`mass_ratio = m₁ / m₂` must lie in `[0, 1]`; the masses are normalized so that
`m₁ + m₂ = 1`.
- `mass_ratio == 0`: the test-particle (one-body) waveform is returned.
- Otherwise: the relative orbit is split into two bodies with `one2two`.

Throws `ArgumentError` for a `mass_ratio` outside `[0, 1]` (including `NaN`).
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    mass_ratio ≤ 1 || throw(ArgumentError("mass_ratio must be <= 1"))
    mass_ratio ≥ 0 || throw(ArgumentError("mass_ratio must be non-negative"))
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # m₂ = 1 / (1 + q) and m₁ = q m₂, so m₁ + m₂ = 1 and m₁ / m₂ = q
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
# Simulating the True Model
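`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^{2}\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2}\sqrt{(p - 2)^2 - 4e^2}},$$

where $p$ is the semi-latus rectum, $e$ is the eccentricity, and $M$ is the mass of the Schwarzschild background. The orbit is reconstructed from $(\chi, \phi)$ by `soln2orbit` via $r = p / (1 + e\cos\chi)$.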
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = [χ, ϕ]` with parameters
`(p, M, e)`. Returns `[χ̇, ϕ̇]`; neither rate depends on `ϕ` or `t`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)
    # factor shared by both rates
    shared = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)
    rate_χ = shared * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * root)
    rate_ϕ = shared / (M * (p^(3 / 2)) * root)
    return [rate_χ, rate_ϕ]
end
mass_ratio = 0.0 # test particle
u0 = [Float64(π), 0.0] # initial conditions (χ₀, ϕ₀)
datasize = 250
tspan = (0.0f0, 6.0f4) # timespace for GW waveform
# NOTE(review): tspan is Float32 while u0 and ode_model_params are Float64, so tsteps
# and dt_data are Float32 as well — confirm this mixed precision is intended.
tsteps = range(first(tspan), last(tspan); length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1] # spacing between saved samples
dt = 100.0 # fixed step size handed to the RK4 solver
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
# fixed-step RK4 with step `dt`, saving the state at the sample times `tsteps`
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# keep only h₊, the first entry of the (h₊, hₓ) tuple returned by `compute_waveform`
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (the orbital variable $\chi$, i.e. the first component of the state `u`) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.
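Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
We deviate from the standard neural network initialization and use WeightInitializers.jl: the weights are drawn from a truncated normal distribution with a small standard deviation (`1.0e-4`), and the biases start at zero.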
# The first layer applies `cos` to the raw input via `fast_activation`. Every Dense layer
# draws its weights from a truncated normal with standard deviation 1e-4 and starts
# with zero biases, so the untrained network's output is small.
const nn = let weight_init = truncated_normal(; std=1.0e-4)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight=weight_init, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight=weight_init, init_bias=zeros32),
        Dense(32 => 2; init_weight=weight_init, init_bias=zeros32),
    )
end
# `Lux.setup` draws the initial parameters `ps` and creates the layer states `st`. Both
# are NamedTuples keyed by layer (`layer_1` … `layer_4`); every layer's state is empty
# for this network.
# NOTE(review): `Random.default_rng()` is not seeded, so `ps` (and every result below)
# differs from run to run — seed it if reproducibility matters.
ps, st = Lux.setup(Random.default_rng(), nn)
# Similar to most deep learning frameworks, Lux defaults to using Float32. However, in
# this case we need Float64.
# Float64 copy of the initial parameters, flattened into a single ComponentArray vector.
const params = ps |> f64 |> ComponentArray
# Bundle the network with its (empty) state `st`. The parameters are not stored
# (`nothing`), so they are passed at call time as `nn_model(x, ps)`. The rendered docs
# report 1_186 parameters (64 + 1_056 + 66 across the three Dense layers) and 0 states.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Now we define a system of ODEs that describes the motion of a point-like particle with
# Newtonian physics, using
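$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left[1 + \mathcal{F}_1(\chi)\right], \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left[1 + \mathcal{F}_2(\chi)\right],$$

where $\mathcal{F}_1$ and $\mathcal{F}_2$ are the two outputs of the neural network evaluated at $\chi$. With $\mathcal{F}_1 = \mathcal{F}_2 = 0$, these are the Newtonian equations of motion.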
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = [χ, ϕ]`.

Both rates start from the Newtonian rate `(1 + e cos χ)^2 / (M p^{3/2})`. Each is then
multiplied by `1 +` the corresponding output of `nn_model` evaluated at `χ` with
parameters `nn_params`. The orbital parameters `(p, M, e)` come from the global
`ode_model_params`; `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # `nn_model` already carries the network's (empty) state `st`, so only the
    # parameters have to be passed here.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
# Same initial condition, time span and solver settings as the true model, but with the
# neural ODE right-hand side and the untrained parameters.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the h₊ waveform predicted by the neural ODE with parameters
`θ` and the true `waveform`.

The prediction uses the same initial condition, fixed-step RK4 settings and sample
times as the data generation.
"""
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
# Warmup the loss function
loss(params)
# Output in the original rendering: 0.0007278774577885794
# Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(state, l)

Optimization callback. It:
- appends the current loss `l` to `losses`;
- prints the iteration number `state.iter` and the loss;
- returns `false`, so the solve continues (in Optimization.jl, returning `true` stops
  it).

`state` is the optimizer state object, not the parameter vector.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. It seems to give good results here because the Newtonian model already provides a very good initial guess.
# Gradients of `loss` come from Zygote. The objective's second argument (the fixed
# problem parameters of OptimizationProblem) is not used.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step norm and a backtracking line search, starting from the
# untrained parameters.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback,
    maxiters=1_000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-4.886792157770895e-5; 0.00016075129678929363; -0.00016998588398548235; 0.0001129224547185144; -5.564314778887167e-5; 0.00012008092744493718; 0.00016823101032039243; -0.0001441463391528674; -5.413466351462354e-5; 6.50019137537045e-5; 0.00012562083429655334; -0.00020957156084465978; 0.00025607223506072087; -4.0567159885522505e-5; 0.00017357882461494458; -0.00010283757728744105; 4.302182060189786e-5; 0.00011678414011834077; -3.371517323098137e-7; -1.0643624591460247e-5; -9.766072616920061e-5; 0.0001532693131593447; 0.0001237312681039869; -3.472470780250917e-5; -6.359593680832716e-5; -3.440402360861348e-5; 6.828547338948212e-5; -4.1507271816941054e-5; 9.07998764886519e-5; 2.7130350645127808e-5; 2.4925935576866104e-5; 0.00013315188698483638;;], bias = [-6.443335495390771e-17, 9.859246665347182e-17, -1.8103905845340923e-16, -1.6050637452914926e-17, -7.664294183170835e-17, 1.6889242139994989e-16, 2.243924872607419e-16, -8.024660964663579e-17, 7.083520070650748e-17, 1.099868739216659e-16, 1.6116331799382436e-17, -1.2482269528775812e-16, 1.4623630775605008e-16, -1.998800160224276e-17, 1.8457128483069351e-16, -1.475750080461831e-16, 1.2155216851497726e-16, 6.533154783925562e-17, -5.467776609551338e-19, -7.047509837412942e-18, 8.453073509412794e-17, 6.098079707109613e-17, 1.993253378368572e-16, -1.8887025031078053e-17, 6.558395641841723e-17, -5.0241503457494066e-17, 7.714538897335094e-17, -1.452633593523543e-17, 1.0176479560522766e-16, 4.191630638467515e-17, 4.0994570496589855e-17, -1.4461953537210081e-16]), layer_3 = (weight = [-8.639841024140715e-5 2.7113433423634988e-5 1.1284752252520594e-5 -6.159643879235396e-5 -4.191043324843853e-5 9.666019939881625e-6 0.00014247231530882133 1.8062513225428082e-5 -9.083414335498753e-6 0.00012621929574216863 -7.371682357350908e-5 2.0599876114121078e-5 -1.0987768002245394e-5 3.425481492578149e-5 1.059682592150974e-5 4.808833120315616e-5 -4.2765110885546453e-5 
3.0782778883998935e-5 -7.659255850902695e-5 -1.4156488470497356e-5 -0.00011148084075265305 7.056100747119288e-6 1.3332093947805992e-5 0.00011024985662353312 5.8504668521174766e-5 -2.5063208051313706e-5 -0.00011455468546208765 -3.285524206311343e-5 5.72655728157144e-6 6.486205467646089e-5 -4.531662187381572e-5 -0.0001497453201243704; 4.955297476672488e-5 0.00019201949539535083 -1.611262273685173e-5 -3.1653766248483734e-5 0.00015778211304396216 3.635917887266809e-5 1.6862351864678635e-5 -3.690334420127676e-5 4.076023464849995e-5 4.7496137946747586e-5 3.3044689141101054e-5 0.00020876016419319078 5.011462412476847e-5 0.00017463219769243177 -2.6670773925742066e-5 4.796834031994151e-5 2.1377811160078254e-5 7.15986788662915e-5 -6.544272405616679e-5 0.00015514207547940556 -0.00010227960233633564 -6.3877592339195355e-6 1.469123430696042e-5 -0.0001482252519536754 8.356693792566345e-5 -2.957760994203763e-5 -9.466030908787427e-6 1.2775546517246562e-5 0.0001686280047096222 0.00011085404907763159 -2.320245226027738e-5 -0.00012485196571389596; 1.3754059403761923e-5 -5.987814598160147e-5 9.329985390973956e-5 -0.00011830220115047348 6.795405556072923e-5 -0.00013544073386259763 0.00017496171698763705 0.00024125056574444888 0.00021426542138538763 1.2538703452772332e-5 0.00011756383377683907 -9.270526584845153e-5 -6.520801776463086e-5 0.00014184506493944172 -0.00010831216582704576 -0.00022723692177936856 -5.379049048901105e-5 -0.00011700621485619268 -0.00011865554348009225 9.626478480964383e-5 -3.697442458036787e-5 -7.872740184797276e-5 -6.945513974318658e-5 1.4383478850102578e-5 3.6510087960376255e-5 -1.3450936200062686e-5 7.398723595866459e-5 3.460162610604857e-5 -0.00013118207221527093 -3.0448063404439465e-5 5.188459220842467e-6 4.8018069941062646e-5; -0.00010011453493388631 0.00023274450425634068 -0.00014031299294328385 -0.0001419013053866308 0.00012289937215501185 -1.9607408166115033e-5 7.83586954010457e-5 3.589829496450596e-5 -3.6441054607343616e-5 1.9785812338164405e-5 
2.2173330517768163e-5 7.172836260980036e-5 -1.9945198136328595e-5 -3.509789464187141e-5 5.939621472026162e-5 2.3875004199731953e-5 -1.953566176707189e-5 -7.888058395946576e-5 6.993399504681771e-5 -5.600252524522801e-6 -6.804072406266095e-5 -6.0298483066792735e-5 0.00014938733577838002 -2.813097423240345e-5 -5.3540930156663666e-5 1.5400769046242893e-7 -4.5206429848980724e-5 -3.012312233221982e-5 -1.2496354674289313e-5 0.00011609654096064816 -0.00020166637047549984 3.8695518679384525e-5; -0.00018479488491797766 -9.52377638176157e-5 0.00016576171335977598 1.1729730685285143e-5 0.000200454803488497 0.00010586509895724827 2.3737812879650117e-5 0.00013093134046280192 -0.00014428804272105883 -7.510547286150983e-5 0.0001250768866877032 -0.0001457146979381465 4.195227964294861e-5 -8.774403679203163e-5 -4.1498406718214454e-5 0.00010541189410937729 6.741976838307528e-5 -7.707046890650269e-5 0.00010396581936143475 2.1554456251721287e-5 -4.060490228627573e-6 8.836058557712747e-5 8.165127954193942e-5 0.00013789577067578303 -2.1852863218053313e-6 -0.00013214911152394397 5.617659352678823e-6 -9.57643416845162e-6 -0.00013115007888777323 4.695645409734099e-5 -6.60807315587986e-5 -1.1885677812187004e-6; -8.237693092817461e-7 -4.142090716340129e-5 -6.60360174690084e-5 9.584135745246986e-5 2.820132880591593e-5 0.0001282894654112392 2.142330862698777e-5 0.00010773972366297716 -3.084775476887532e-5 8.090180103030196e-5 -4.904053372154268e-5 2.571819544826867e-5 -4.5302696034232374e-5 6.323955944102696e-5 8.752052139316641e-5 9.32101529004552e-5 0.0001947676927928482 0.00012780207811449645 -7.5437301885879265e-6 0.00015066495592767987 -4.805447411778853e-5 -5.9445742462400606e-5 6.960438526295855e-5 0.00020686001420548678 -8.24243989510818e-5 7.793716523193544e-6 -1.3139803794877628e-5 -3.328046392121855e-5 -9.499249523211594e-5 -7.861324828953181e-6 0.0001272641374642644 -0.00012904837882086182; -3.796536681353727e-5 4.515235353732056e-5 7.872703439328417e-5 1.3881616483318445e-5 
-4.912639839336883e-5 4.486541887284672e-5 -5.023563995952192e-5 2.0447362073685528e-6 0.0001148892858834504 1.216946724137858e-5 0.00016546819982589652 2.6377751835192158e-6 -0.0001275390807352412 -0.00028945693180891367 4.835488447109778e-5 -3.9475557299978915e-5 -1.1616871474202832e-5 -4.96396007877268e-5 0.00012306768779647833 4.197221856642964e-6 -4.624281269958799e-5 -8.030450250249801e-5 4.826406233017754e-5 -0.0001671188931725423 -7.40688831358649e-5 -1.0386833550242405e-5 -8.83573126362998e-6 -6.379858520109823e-5 4.4877329615460635e-5 0.00016540978843816858 0.0001498066428537045 -2.001502458757745e-5; -9.596363779848974e-5 -0.00012508840846120296 -0.0001289399948319274 -0.00010694619377600917 0.00012453677518067452 1.6019914616568087e-5 -0.0001422967321587589 -2.2859195817631667e-5 0.00013591859933228712 -0.00011637645330040048 -0.00014550580781642794 -0.00017860341140754602 -0.00014177753437527184 -3.037463430075614e-5 -3.6449873371638725e-5 4.7048779740079376e-5 -0.00012996260157080978 0.0001318664002582355 -0.00020261191785096142 0.00016181792343274756 4.0995153901246805e-5 6.786194753377442e-5 -0.00019707024294175722 0.00011713601306553226 -0.00016662351809202053 -8.236597168718239e-5 1.9016410439580535e-5 0.00016949820601155352 -8.902788213446414e-5 -1.203076463127814e-5 -0.0001334953263750565 0.0001232859507553263; 0.00014993552360384354 2.519936753661275e-5 5.883828236538329e-5 -9.52114900254783e-5 4.261632749977149e-5 -0.000130692413568697 7.91994058264268e-5 0.0002970768797414232 -2.2598546257441854e-5 -8.924348765585023e-5 -0.00013323137353008163 3.177896325642461e-5 -2.157937377059794e-5 3.137565692587395e-5 3.668528699482192e-5 -2.145836731951188e-5 -0.0001265079812409486 -4.9100730739966917e-5 7.405438702254211e-5 -9.352286395450469e-5 2.3227857690570333e-5 8.231775756858678e-5 1.1989128887860477e-5 -7.155912715542486e-5 -9.394526239783885e-5 -1.3751622992359416e-5 0.00013123142309187525 -5.661219291222643e-5 -0.00011263192243940612 
7.619899008123784e-5 -0.00010968787353785127 0.00015392880197314982; -5.277560166457134e-5 -8.266633533871827e-6 -8.204602244682342e-5 1.1696222197257385e-5 -0.0001161196666481102 -6.611837642873311e-5 -4.0094386070498744e-5 -3.624716958309757e-6 -6.24525816356893e-5 1.9085642913460653e-5 -1.669743841654212e-5 -3.111510858893154e-5 9.148769850705435e-5 5.920576978217368e-5 2.3993234487486698e-5 -1.1249540596568465e-5 1.7131136980359793e-5 -8.653700542078072e-5 -2.811084386218881e-5 2.0334149946327563e-5 0.00016056490522435902 -5.1152434871770625e-6 -6.031115817046344e-5 6.13177365329908e-5 -6.660504340567923e-5 -3.984526455776306e-5 -1.2083465372316875e-6 -1.5588013132465724e-5 -9.202085829480594e-5 -5.0538372027469184e-5 -0.00010003742835858411 0.00011182484404975143; -0.00016497268263063715 2.7380302692293114e-5 2.8514235220610535e-5 7.296449728729338e-5 0.00014539101646934285 8.736990383314024e-5 1.145943059073174e-5 4.9861367353701436e-5 5.200252160849637e-5 -0.00010968497264710658 4.4829598830081456e-5 0.00010955600844881118 7.632248269345161e-5 -0.00010345222728372835 -0.00024888420631923454 2.6867080289057936e-5 9.862338736840231e-5 2.0105931236875232e-5 -7.350112611473745e-5 0.00014520494112931908 8.140099196881338e-5 -9.65735224033742e-5 -7.194508254341766e-5 -7.208738572242933e-5 0.00013204431996401933 -0.00016867573760262955 -0.00012579198557231068 0.00010297000806732747 4.334189105267047e-5 1.3507149726333943e-5 2.7603434484458956e-5 0.0002007409722702568; 0.00010661927129946436 7.530856286281607e-5 -0.00013877941770373015 -0.0001232210774662549 -4.897934510266619e-5 -2.9269290699844005e-5 -7.961133038519054e-5 -3.0410308191889386e-5 -0.0001988130529400541 0.00033038463201479784 0.00020708568676015203 -2.0753817317411676e-5 -6.933336359109677e-5 -0.00010245596737262078 -0.00014772819073299043 0.00013550328780071632 -0.00015538129033414863 4.262376800285486e-5 -4.2729224752506615e-5 -0.00014039008188886792 -8.774609838040728e-5 -3.448417142975268e-5 
0.00014315296770179482 8.939355806320687e-5 1.1742961340828596e-5 1.7235729081917225e-5 0.00017430385972944814 -7.892022355120845e-5 3.2285707118401525e-6 0.00011609341212421435 2.870397705807129e-5 1.452659083840056e-5; 0.0002448344456394233 9.16444185901274e-5 0.00011361734492367154 -4.0212780858721407e-5 -9.29679130364438e-6 0.00018361660464563662 -0.00013270943612895834 -8.484220610569486e-5 -1.159246325129393e-5 0.0002660719763849876 2.6531153581004197e-5 -2.6393406815254555e-7 -3.369964919276674e-5 -3.549578845137377e-5 -0.00016327302740991195 -5.878573246600732e-6 6.405491863973777e-5 1.1888419849698829e-5 -6.20357315159354e-5 0.00013758696137359844 -0.0001003612852758756 0.00012525485350180874 -2.6700350771875137e-5 9.29634623941073e-5 -8.0792255302517e-5 -1.077728042548278e-6 2.5217177481531343e-5 4.5759421355467044e-5 -0.00014652323456118058 5.5353939260965754e-5 8.860845979201114e-5 -8.12615520715258e-6; 4.158640048813376e-5 -5.929390418343577e-5 -4.483855546452154e-5 0.0001020744592952388 -0.00017693481013879747 -1.948769486801686e-5 4.2503902381284987e-5 -8.450145383839088e-6 -0.00017733759260040358 -0.00011086274853909849 -7.769034806127672e-6 5.047707135790432e-5 -0.00011268929766254894 0.00011232147416563007 0.00015028582756174792 -3.698856845034162e-5 -0.00012929221803092142 7.00834488216642e-5 -0.00020616610082101926 -1.795586204566435e-5 -9.015401984101131e-5 -4.5966816383885115e-5 -5.1161602737860735e-5 9.028823741941905e-5 0.00012278016119494556 5.168707402304006e-5 3.656985690582123e-5 -9.11940452223899e-5 -8.19091375032635e-5 -0.00017651024346009676 -3.0025842630447366e-5 9.89472236088369e-5; -2.446300780705251e-5 7.231850983386087e-5 0.00011679566747785613 0.00013238594916926313 -0.00014836834228320142 -5.013527109825485e-5 0.00011814129035498102 3.180601057989833e-5 0.0001823581719379677 0.00012918641956800465 -8.561802454481374e-5 1.849047518204055e-5 -3.790730400783099e-5 4.369001058667839e-6 2.2752194270302777e-6 -7.086350300072079e-5 
-1.647165985034837e-5 -2.773546978845204e-5 0.00010650428381793197 -0.00014994117148874332 3.365841845676948e-5 -0.00012021470095918284 -4.9847150452691056e-5 9.15698569458719e-5 8.858717270938557e-5 7.118473373863807e-5 0.00012914923942459607 -1.939787174181136e-5 1.040001362315513e-6 -0.0001752596849966983 -5.130587809092386e-5 -7.142303141721124e-5; 0.00013407250387251698 -4.164436088753763e-5 -7.102031943500211e-5 -9.079897633578084e-6 0.00011299033664988615 0.00011843265108187867 0.00015423479360183583 7.413592991075098e-5 -0.00011318621821112656 0.0001826286070107685 -9.939840222353982e-5 0.00016199955004852625 -5.780917402006389e-5 -2.6416048992472537e-5 6.604099778346437e-5 0.000161717970488887 0.0002350908953056894 -2.6916534831921814e-5 -7.718288091120705e-5 -9.129954107913593e-5 -8.65458596555361e-5 2.8497131072849347e-5 -4.319257916644108e-5 9.616268377874476e-5 -0.0001285677963343245 0.00012247475840274674 9.142225914990651e-5 -8.431700828361873e-5 -7.996550541714029e-5 8.776439515065873e-5 2.0756807291833828e-5 -0.00013359875722666377; 0.00017546112722931266 1.5064267801157913e-5 -8.951604321618691e-5 -8.759174887808552e-5 -9.316600325006994e-6 -3.6375395565020026e-5 -0.00011132286671335915 -1.6014966957930853e-5 -6.89273678187299e-5 -5.811687004887723e-5 -1.1126903144660496e-5 -2.9166595039625444e-5 8.889669263908083e-8 8.402141030393402e-5 6.236885893604551e-5 1.6229558798935383e-6 4.326373674422358e-5 -4.95052144168836e-6 8.778268021279862e-5 -0.0001250292138035418 4.4691337835784687e-5 0.00021889277341988927 5.959680639653799e-5 4.560954913478222e-5 -8.575793107291901e-5 1.3630587647015867e-5 7.57215799342817e-5 -2.8460861711834102e-5 0.0001440866450241731 -3.206473628659839e-5 3.925096244648927e-5 -5.7212457603502366e-5; -9.355507016367296e-5 6.89946687249567e-5 -2.153842932089613e-5 4.922351264908002e-5 3.6968122474392745e-5 -4.500912366308906e-6 -4.317293116725394e-6 -2.4658889271774728e-5 -3.175377376679345e-5 0.0001916145852352585 
0.00010987345787511241 1.1927586017594683e-5 -0.00013939198059447817 -0.00014043362123844234 -1.8772523146578154e-5 -5.532403504146344e-6 -2.0088341882326815e-5 7.219168084488278e-5 9.843360796152589e-5 0.000202699201070247 -2.377104054390442e-5 -2.3228528773285742e-5 -0.0001422333730100558 -5.638295577442294e-5 -5.8367706044335016e-5 -2.0755663610869286e-5 -4.768853203346029e-5 -9.815122650577865e-6 0.00011634439626290467 5.00849933065565e-5 -0.00012192576057928637 -2.9444168360312048e-5; -3.2296286772014864e-5 -0.0003223563133722838 8.651513522651704e-5 -3.502150397236948e-5 8.681662180622979e-5 -9.230343388464864e-5 -1.2037047635959625e-5 0.00032374048638318746 -0.000184013554498796 -1.0040553970447453e-5 -0.00011541398529500394 -7.617537889003765e-5 -0.00010099567120037085 -4.3647395468559075e-5 -0.0001042462270923644 8.133109546195324e-5 -3.262580034044576e-5 0.000176733488078913 -2.6567884922927933e-6 -1.803703602474111e-5 -3.652682684316647e-5 0.00011784576523511636 0.00011510107757219442 8.776073551433185e-5 -0.0002531674861891764 4.104356012414453e-6 8.790245661673937e-5 3.3605138940562434e-7 0.00011126549830821985 0.0002038116191762877 0.00020698109907260216 0.00012105730016644334; 4.108963738979084e-5 0.00010275266876413849 -7.080176684104336e-5 7.351481516814569e-5 -4.6710594605269886e-5 -9.045267316543646e-5 9.679939293272337e-5 5.442663868160124e-5 4.064782016189182e-6 3.1385878548239057e-6 6.579733789401974e-5 -6.6019960366293235e-6 -8.464196972344719e-5 -0.00010232977175185249 -3.1195924665382154e-5 -0.0001622279068295224 -3.391258533959057e-5 -1.234958100105354e-5 -2.0627536204143933e-5 -6.413040492482343e-5 -7.281644312463135e-5 8.235019239799377e-5 5.169882489834739e-7 -5.224039151364748e-5 1.0171849763696925e-5 -4.2752408179536435e-5 9.169333688184964e-5 9.854611752523321e-5 0.00013218559805971145 -4.2521407439222736e-5 -0.00022805263680204585 -4.9160085753845055e-5; -5.6669769617850165e-5 7.515994853910406e-5 -4.004309688590072e-5 
-2.7037886537265473e-5 0.00019994002410547796 1.3312928205585856e-5 -0.0002970315278054922 3.5416383399742545e-5 -0.00013207741961943 0.00021403144397644334 3.307065468269589e-5 -0.0002739150828694675 9.344085752836467e-5 0.00010853050240419776 3.610413601726708e-5 -6.469435414893444e-5 -2.5806587144155344e-5 2.5862147014585023e-5 2.9741478430684073e-5 1.531729263823423e-5 0.00019918057420162575 0.00010390640578570155 0.00011809354087908003 6.856625019444117e-5 0.00010362080262146898 -0.00014306871402205943 -3.477605299686999e-5 -0.00019493952517060596 -6.9609409036468e-5 -9.621367884229911e-5 -0.00012404317625461196 -1.602685973758589e-5; -9.277252423789792e-5 -9.011548297812459e-6 -0.00013463414246663084 -4.419957725777211e-5 -7.136039233746476e-5 8.311111616690305e-6 4.7487540416376845e-5 -7.683434990913846e-5 8.140750526126992e-6 2.846807807354676e-5 3.3955662369825186e-7 -0.00024486238284018115 0.00011382763738063001 -0.0001265642905652377 7.118660402037524e-5 0.0001267342912454268 0.00012284647060578275 0.00010210732112891083 0.00016327068702803583 8.422263168749981e-5 -3.526664398823657e-5 -7.293103078799312e-6 0.000229685628130296 1.6392222152991025e-5 7.023030308328062e-5 6.113578534705173e-5 0.00011691510268274164 -1.5028966243061635e-5 0.00010038343576734212 -0.00025271104563007505 -6.034943829736109e-5 4.287340163767779e-6; -2.414374438606768e-5 -0.00014870470448888595 -6.718106277778395e-5 -4.876720768633281e-6 0.00020293954758878255 -0.0001543814648271367 -0.0002268450113574755 2.7309010246857767e-5 8.699906930766452e-5 5.95254826380076e-5 -0.00011120752577928931 8.581559367924296e-6 0.0002328950579881085 7.960954854703847e-5 1.8072898737480124e-5 3.243819838596225e-5 -0.00020669200575005846 -7.701763360321698e-5 -9.060948622439705e-5 -9.144905897349763e-5 -0.00011102393154081275 -2.5028468090135837e-5 6.0902843242249194e-5 -0.0001149078668190972 2.3374201472932047e-5 -1.3205052428505403e-5 5.9700600385865916e-5 0.00016526583231958906 
0.00010496665373418118 -1.2139674395562587e-6 4.0451792583068156e-5 -2.1289153578683136e-6; -3.5712696109309646e-5 -1.7355582378005947e-5 5.905270824637677e-5 -1.4130164919540915e-5 7.444920935866555e-5 3.6512460945751124e-5 8.5921306316373e-5 -0.0001424845801297455 -2.208543867838421e-5 7.953089822747769e-5 -5.9490431005025135e-5 -0.00012987357025073026 -3.6779842622749436e-5 -5.669442715915178e-7 0.00010415922894804103 9.131862622650675e-5 -8.022668465074829e-5 8.57883548503313e-6 3.305603910902018e-5 2.397146033089099e-5 6.246708960004589e-5 -6.306347189041775e-5 0.00010320312446180604 0.00018468642098795005 -0.00011329611573402728 0.0002585572541361986 -1.6616439671833268e-5 0.00012975534206687735 -0.0001276667868582289 5.8550667171025545e-5 4.009696510237083e-5 0.00017818239882086976; 0.00016829533826784074 -1.4561234524124862e-5 4.581285004491915e-5 -7.714496193696433e-5 -8.403443523458454e-5 7.035030405492608e-5 1.7360141109272594e-6 5.267257195187894e-5 -6.242407714491005e-5 -4.435053658564044e-5 -2.5432217329611115e-5 2.0776145289277868e-5 6.483782756358097e-5 2.0048034753915313e-5 0.00022817310794011518 -7.220211289134202e-5 -3.7206106457245426e-5 3.7660411871633817e-6 0.00011254051908143963 3.512915402642605e-5 -3.765999439587587e-6 4.5684422117078486e-5 9.025522861153786e-5 9.269602135281084e-5 -0.00011531780062252102 0.0002051359854940419 6.322220117543365e-5 0.00015754010503404712 0.00018791775194220233 2.282619182186736e-5 1.99388226301278e-5 -0.00013041630761475662; -0.00011267388189916532 -0.00017550201600424693 0.0001429343871724995 8.707317379234913e-6 -7.090977605438843e-5 0.0001037217850400378 7.472570091704385e-6 -6.217458148493969e-5 2.4613744885944284e-5 8.288478120922028e-5 -0.00021538034643053475 -4.1284903358977825e-5 7.270151100345984e-5 3.927131247583765e-5 4.043636063688572e-5 -6.815621899913323e-5 0.00011443196554428879 -3.753373247927395e-5 0.00020578518588555082 9.827469030388662e-5 -3.9607063895158894e-5 -2.897827043379803e-5 
8.708977539318227e-5 3.3598601208003235e-5 5.6318964839381315e-5 9.626364472315646e-5 -0.00017889721065766275 3.037988854449228e-5 -2.1165640004977652e-5 9.578021554735501e-5 2.634230141230409e-5 -0.000102104178308579; -9.747911058566574e-5 -9.824374824728497e-5 6.182607507600006e-5 2.741622160629435e-5 8.396317601167495e-7 -1.3586880917677944e-5 -0.00018139937269933068 7.52126074492563e-5 -4.925617354656578e-5 -5.554651539041598e-5 -2.3962625657022046e-6 8.84932200808463e-6 0.00017566545324901762 -3.55843513189255e-6 -0.00026990867586772324 -7.526627410170209e-5 -7.594484446072089e-5 -0.00012490568001721403 -9.753391309841706e-5 0.00011939654400138953 0.00010432977473629062 9.155272750450489e-5 0.00015418848529025656 -2.3566397025390643e-5 -2.8552299731200977e-5 0.0001986099236508355 -0.0002455746430807273 9.127610287195927e-5 -0.00014599018282224143 -1.2169489813643878e-5 0.00014053196395779824 -7.839245477425061e-5; -4.18228932861653e-5 -4.412842233741364e-5 0.00010113647524540334 4.540840247920307e-5 0.00010843120483025427 -8.401975126197186e-5 -5.975339241011963e-6 8.435701104154621e-5 -9.684163294457364e-5 0.0001725531206144521 4.579389362753624e-5 -2.5694469036729176e-5 4.408966790339466e-5 2.7829439681797495e-5 8.485323135082524e-5 -7.179081646916672e-5 0.00010441433780636271 -3.01224112948917e-5 1.7141614557320995e-5 4.174881407800846e-5 -1.6167965444588275e-6 7.983458137878719e-5 -2.8807127342030235e-5 4.483689056053445e-5 -0.00015252562068360492 0.00012976136579220807 0.00013580151655753455 8.396276327823863e-5 1.8441733595519257e-5 1.5694763100898913e-5 -1.1212333879822678e-5 -9.821078626863575e-5; -7.460693910797608e-6 5.131297524512356e-6 -2.8845147613740834e-5 -9.179571425757677e-5 0.0001467459522189527 -7.640680924506367e-5 0.00010162952893895393 0.00013343019308539078 -0.00018006918181181394 -8.304328294452824e-5 0.00014508438543433703 -0.0001389530072685707 9.385887896655199e-5 9.852214206494054e-5 2.8236225983869843e-5 9.203051085368603e-5 
0.000111149202706475 -0.00011009173541079655 3.183084190259484e-5 -8.582926166691102e-5 1.867749487591396e-5 -4.8204790078869686e-5 5.1178377174245656e-5 -0.00013448207683377272 -2.3756169468479537e-6 -5.0448389816845967e-5 0.00014723243729695095 0.00011870564121072769 8.41505250662478e-5 5.235856660310847e-5 2.534948616248767e-5 0.00019204794152191134; 4.72860559839728e-5 0.00016551491388422366 -0.00015585716356686366 1.3900160056855044e-6 -8.445710738119933e-5 0.00013474878904712212 0.00022678575295326245 -8.635645246063478e-5 7.072065308061162e-5 0.00011644623408396473 -4.988945839973326e-5 5.838793082192591e-6 6.214905066418223e-6 -0.0001699941891401568 8.469577902403725e-6 -0.00017423554856027678 6.6469559770166255e-6 6.887113375891754e-5 -0.00011640342892310873 0.00017579111262061345 -8.189092806237964e-5 5.128761075506822e-5 9.974872626666749e-5 0.00015655342248105167 -0.00010806949614011941 0.000266566731066673 7.81050661189922e-5 -0.00013113360511302291 -3.276276367064789e-5 -0.0001420349104718799 -0.00028136960064689914 -0.0001377782569888528; -6.19424204667023e-6 -4.254525370520969e-5 2.1523729487347685e-5 -3.093392038460146e-5 -0.00010342891360375912 -2.430362570126154e-6 -9.976585857254064e-6 9.69113778814415e-5 0.0001673167888333659 -6.298281874089077e-6 5.964123087727384e-5 0.00013310165636035303 5.9778426334075705e-5 -2.1210965456314987e-5 -5.56078749973997e-5 -0.00025870081768113564 -0.00018639171062711583 -0.00018147189726329613 0.00011852991414746208 -0.00011910803473763224 0.00010227605784569424 -0.00021721447151827264 -5.245047319073337e-5 0.00012090850472707307 0.0001046364439793507 -1.791113859833212e-5 -3.294612648057544e-5 0.00012033381321491709 6.080049110376423e-7 -7.778272734465096e-5 -5.620470724859301e-5 -7.891178134337966e-5; 3.1509959624870227e-6 4.3562971495543483e-5 -0.0002085419130709635 -5.052504484426444e-5 -5.0488399483738646e-5 8.903777570578951e-5 7.39267104107275e-5 -6.678124400439118e-5 2.1989668039013473e-5 
-9.63042787590627e-5 -9.205474116258597e-5 7.026361072185399e-6 2.7841415157868762e-5 -0.00010408293956076716 -4.8048083317713864e-5 -0.00011413765814738516 3.6348142839260465e-5 0.000164399459845142 -4.6706032935771646e-5 -3.4684572074338376e-5 -0.00010621993924375771 0.0001586460400174701 0.00010618827736818897 7.242733199730536e-5 0.00012010060681844092 -3.4031351151652433e-5 0.00010640553746255066 -1.935689606413703e-5 -2.820728286834847e-6 0.0001788332135218571 8.212309849098356e-5 -5.907947188477346e-5], bias = [-6.409043533035e-10, 4.3837161760634414e-9, -2.8172184298987375e-10, -2.2489002251568046e-10, 1.0361683568520363e-9, 3.5497040178013125e-9, -2.458039175501166e-10, -4.481628937113344e-9, 2.5795007708494766e-10, -1.8732799364386683e-9, 1.4198869539159253e-9, -1.3047723323251478e-10, 2.4882507559132563e-9, -2.8604571978814757e-9, 9.190132654843579e-10, 2.7548347839604873e-9, 1.0455653994475551e-9, -1.1594452919880257e-10, 1.3798491478212556e-9, -1.3805300356278334e-9, -4.826126439693617e-10, 1.2395160424597192e-9, -1.0692843114560348e-9, 3.015310163295056e-9, 4.184881608748057e-9, 8.482904456123831e-10, -1.7021686321453207e-9, 2.738056199433208e-9, 2.4890994632709928e-9, -2.579453896101269e-10, -1.7872985831184067e-9, 6.236403626104451e-10]), layer_4 = (weight = [-0.0005653352680247515 -0.0005058386425647957 -0.0006851167375934766 -0.0006338841968247037 -0.000810085085567813 -0.0008391747349860614 -0.0006867316231296072 -0.0006740625775751847 -0.0006590083327691357 -0.0005798030663795085 -0.000691994594901225 -0.000705463932577439 -0.0007611984749287035 -0.0005866962089886381 -0.0007082163927918892 -0.0005511536347655217 -0.0006730617445470481 -0.000828201058470692 -0.0005810112725802559 -0.0008830613711361249 -0.0007334273545931898 -0.000617082465966667 -0.0006952667291432849 -0.0006585163331080974 -0.0006923532988620416 -0.0006728730904358961 -0.0007389797534461505 -0.0007113114645163474 -0.000797360816927932 -0.0007573258814180911 
-0.0007447001989769274 -0.0007668075165279793; 0.00022374716554960123 0.0003569595877003191 0.0001665827097223377 0.00020601200099210933 0.00028884857438686665 9.528743402180254e-5 0.00033455203671330623 0.0003537037429237981 0.00021216570296620554 0.00017559635029060913 0.0005267705990089065 0.0003734393529161941 0.00016956308729926142 0.00023726379892102315 0.00019211655342859173 0.0003195693483321594 0.00013360685135195834 0.00022073471752811921 0.00022688152057051084 0.00013066300681711156 0.00031933902672398134 0.00010851422438717735 0.00016369466369103958 0.0002336408981910503 9.142371103313082e-5 0.000263484665425977 0.00017794272957930672 0.0003231086846805694 0.0003054929299541275 0.0001612414329537182 0.00020294243204574159 0.00017601329442265266], bias = [-0.0006961303354409823, 0.00022444581457892654]))Visualizing the Results
Let us now plot the loss over the training iterations.
begin
    # Training history: one point per value recorded by `callback` in `losses`.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)
    fig
end
# Finally let us visualize the results
# Re-simulate with the optimized parameters `res.u` on the same grid and solver settings
# as before, then convert the orbit into the trained waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    # Overlay the reference waveform, the untrained prediction (`waveform_nn`) and the
    # prediction with optimized parameters (`waveform_nn_trained`). Each series is a line
    # plus circle markers sharing one legend entry.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; markersize=12, marker=:circle, strokewidth=2, alpha=0.5
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; markersize=12, marker=:circle, strokewidth=2, alpha=0.5
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
# ## Appendix
using InteractiveUtils
# Print Julia version, build, platform and environment information.
InteractiveUtils.versioninfo()
# GPU toolchain details are printed only when MLDataDevices is loaded. Each backend also
# requires its package (CUDA / AMDGPU) to be loaded and reported functional. The `if`
# form (rather than `&&` chains) keeps the block's displayed value `nothing`.
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
endJulia Version 1.12.5
Commit 5fe89b8ddc1 (2026-02-09 16:05 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.