Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakieDefine some Utility Functions
Tip
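This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.
We need a very crude 2-body path. Assume the 1-body motion is the Newtonian 2-body relative position vector $r = r_1 - r_2$, and recover the individual positions as $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, where $M = m_1 + m_2$.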
"""
    one2two(path, m₁, m₂)

Split the relative position `path` (`r = r₁ - r₂`) into the positions of the two bodies
in the centre-of-mass frame, returning `(r₁, r₂)` with `r₁ = m₂/M * path` and
`r₂ = -m₁/M * path`, where `M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    first_body = (m₂ / total) .* path
    second_body = (-m₁ / total) .* path
    return first_body, second_body
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in orbital-element form into Cartesian coordinates.

`soln` has one column per saved time and either
- 2 rows `(χ, ϕ)`: the constant parameters `(p, M, e)` are then taken from
  `model_params`, which must have length 3 (`M` is not used here); or
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary along the trajectory and `model_params` is
  ignored.

The radius is `r = p / (1 + e cos χ)`. The result is a `2 × size(soln, 2)` matrix whose
rows are `x = r cos ϕ` and `y = r sin ϕ`.

Throws `ArgumentError` if `soln` has neither 2 nor 4 rows, or if `model_params` is
missing or does not have length 3 when `soln` has 2 rows.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    nrows ∈ (2, 4) || throw(ArgumentError("size(soln,1) must be either 2 or 4"))
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` first so that omitting `model_params` gives this message
        # rather than a MethodError from `length(nothing)`.
        if model_params === nothing || length(model_params) != 3
            throw(ArgumentError("model_params must have length 3 when size(soln,1) = 2"))
        end
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

First time derivative of the uniformly sampled series `v` with spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`. The first and last
points use second-order one-sided three-point stencils. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left, interior, right) / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Second time derivative of the uniformly sampled series `v` with spacing `dt`.

Interior points use the three-point central stencil `v[i-1] - 2v[i] + v[i+1]`. The first
and last points use second-order one-sided four-point stencils. Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    middle = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, middle, tail) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the mass-weighted trace-free moment tensor along `orbit`, whose
first two rows are the `x` and `y` coordinates.

- `component == (1, 1)`: `x² - (x² + y²)/3`
- `component == (2, 2)`: `y² - (x² + y²)/3`
- any other component: `x y`

The selected series is multiplied by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    Ixx = x .^ 2
    Iyy = y .^ 2
    third_trace = (Ixx .+ Iyy) ./ 3
    i, j = component[1], component[2]
    tensor = if i == 1 && j == 1
        Ixx .- third_trace
    elseif i == 2 && j == 2
        Iyy .- third_trace
    else
        x .* y
    end
    return mass .* tensor
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the requested component
of the mass-weighted trace-free moment tensor along `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole strain components `(h11, h12, h22)` for a body of mass `mass`
moving along `orbit`, sampled with spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    hxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    hyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    return hxx, hxy, hyy
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` for a single body of unit mass following
`orbit`. The result is `√(π/5) * (h11 - h22)` and `-√(π/5) * 2 h12`, with constants
built in the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    hxx, hxy, hyy = h_22_quadrupole(dt, orbit)
    amplitude = sqrt(T(π) / 5)
    plus = amplitude * (hxx - hyy)
    cross = -amplitude * (T(2) * hxy)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole strain components `(h11, h12, h22)` of two bodies, each with its own
orbit and mass, sampled with spacing `dt`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    return (
        first_body[1] + second_body[1],
        first_body[2] + second_body[2],
        first_body[3] + second_body[3],
    )
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` of a binary from the orbit of body 1 (mass
`mass1`) and body 2 (mass `mass2`), sampled with spacing `dt`.

The result is `√(π/5) * (h11 - h22)` and `-√(π/5) * 2 h12`, where `hij` are the summed
quadrupole components of both bodies and the constants are built in the type `T` of `dt`.

Throws `ArgumentError("Masses do not sum to unity")` if `mass1 + mass2` differs from one
by more than floating-point round-off.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    total_mass = mass1 + mass2
    # A fixed 1e-12 tolerance is below Float32 round-off, so masses computed in single
    # precision (e.g. m₂ = 1/(1 + q), m₁ = q*m₂) could be rejected even though they are
    # correct. Scale the tolerance with the masses' precision, but never make it stricter
    # than the original 1e-12.
    tol = max(1.0e-12, 8 * eps(float(typeof(total_mass))))
    abs(total_mass - 1) < tol || throw(ArgumentError("Masses do not sum to unity"))
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode gravitational-wave strain `(h₊, hₓ)` for the trajectory `soln`.

`soln` is converted to Cartesian coordinates with `soln2orbit(soln, model_params)`.
`dt` is the sample spacing used for the finite-difference time derivatives.

- `mass_ratio == 0`: the one-body (test-particle) strain is returned.
- `0 < mass_ratio ≤ 1`: the relative orbit is split into two bodies with masses
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, and the two-body strain is
  returned.

Throws `ArgumentError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    mass_ratio ≤ 1 || throw(ArgumentError("mass_ratio must be <= 1"))
    mass_ratio ≥ 0 || throw(ArgumentError("mass_ratio must be non-negative"))
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
# Simulating the True Model
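`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using the state variables

$$u[1] = \chi, \qquad u[2] = \phi,$$

with

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^{2}\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2}\sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.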
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the orbit ODE for a point-like particle in a Schwarzschild background.

The state is `u = (χ, ϕ)` and the constants are `(p, M, e)`. Returns `[χ̇, ϕ̇]`. The
time `t` is unused because the system is autonomous.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    ecosχ = e * cos(χ)
    common = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    norm_factor = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = common * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * norm_factor)
    ϕ̇ = common / (M * (p^(3 / 2)) * norm_factor)
    return [χ̇, ϕ̇]
end
mass_ratio = 0.0 # test particle: compute_waveform uses the one-body strain when this is 0
u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
datasize = 250 # number of saved samples
tspan = (0.0f0, 6.0f4) # time span for the GW waveform
# NOTE(review): `tspan` uses Float32 literals, so `tsteps` and `dt_data` are Float32 even
# though `u0` and the model parameters are Float64 — confirm this mix is intended.
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each saved sample
dt_data = tsteps[2] - tsteps[1] # sample spacing, used for the finite-difference derivatives
dt = 100.0 # fixed RK4 step (adaptive=false below); distinct from the sample spacing dt_data
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# compute_waveform returns (h₊, hₓ); only h₊ is kept and later used as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the reference waveform (h₊ of the relativistic model) as a line with scatter
# markers; the `begin` block evaluates to `fig`.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
# Group the line and scatter plots under a single legend entry.
axislegend(ax, [[l, s]], ["Waveform Data"])
fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes 1 input (the first state variable, χ) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl.
# All dense layers share one initializer: small truncated-normal weights (std 1e-4) and
# zero biases, which keeps the untrained network's output small.
const nn = let small_init = truncated_normal(; std=1.0e-4)
    Chain(
        # Apply `cos` elementwise to the raw input before the first dense layer.
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight=small_init, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight=small_init, init_bias=zeros32),
        Dense(32 => 2; init_weight=small_init, init_bias=zeros32),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00011062599; -8.5733096f-5; 8.0714f-5; 8.0979764f-5; -7.672317f-5; 0.00013390114; 6.6576204f-6; 4.0586114f-5; -1.03308885f-5; -9.856098f-6; -0.00014221299; -0.00016849031; 6.601813f-5; 0.00014108652; 0.00012349701; 3.801642f-6; 1.1083842f-5; 1.6944403f-5; -8.012122f-5; 0.00020991851; 5.8674127f-6; 1.1830696f-5; 2.250284f-6; -0.0001119848; 6.614451f-6; -0.00014941292; 7.001147f-5; -7.381001f-5; 0.00015550862; 3.1217554f-5; -2.3942983f-5; 0.00015155856;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-9.295025f-5 6.121096f-5 8.502578f-5 -5.842799f-5 7.659016f-6 0.00015339178 -9.3544884f-5 -9.63586f-6 0.00019500912 5.1449344f-5 0.00011535053 -0.00020734436 -0.00015997737 -8.6493324f-5 0.00016354781 -7.425347f-5 5.6096473f-5 -0.00010469141 3.3293913f-5 -8.379442f-5 -6.668655f-5 9.467679f-5 -5.6280922f-5 0.00013794476 8.1124155f-5 -0.00019931133 -3.2923316f-7 -0.00010414512 -0.00010282014 0.00011126966 6.633078f-6 -0.00011279382; 3.501202f-5 -4.236184f-5 0.00023467618 0.00011905429 -8.1853075f-5 3.5545705f-5 -9.5798096f-5 -1.414002f-6 7.70363f-5 -6.0257928f-5 7.161419f-5 3.0624342f-6 2.8463633f-5 -0.00014921257 6.905816f-5 -0.000117804084 -6.23111f-5 -5.022439f-5 0.00012546573 -6.972989f-5 -6.261286f-5 5.999882f-5 -9.212584f-5 -0.00012743124 -2.7437753f-5 4.813104f-5 2.6692353f-5 -5.3407533f-5 -0.0001421174 -3.7501955f-5 4.510648f-5 -6.675427f-5; 9.388355f-5 -8.5529f-5 1.421789f-5 -0.00013916042 -6.082237f-5 -0.0001273978 -9.090077f-5 0.000100480145 0.00019333004 -9.198423f-5 0.00011467622 0.00016561006 8.957101f-5 0.000100895086 -0.00014408595 -4.0836665f-5 -1.8079721f-5 4.8282927f-6 3.7670972f-5 5.513295f-5 -0.00015045729 2.070122f-5 6.462496f-5 0.00010790955 6.0074435f-5 -3.1447384f-5 -9.836491f-5 
-0.00014720795 -0.00023096884 5.7770256f-5 -0.00013819191 -8.9625115f-5; -1.16621295f-5 -4.165209f-5 -0.00014866143 -8.193959f-5 5.9778482f-5 -0.0001272149 -3.6873596f-5 -7.976831f-5 -6.2104016f-5 9.568702f-5 -0.000109575856 0.00012060223 0.0001576338 -7.0582835f-5 8.443461f-5 5.205502f-5 -0.0001359865 -5.9990175f-6 -3.6325524f-5 0.00019400624 0.00020900587 -5.8377278f-5 -0.00011205356 -0.00016133064 1.7905935f-5 0.00027469985 -5.6899928f-5 9.615125f-5 0.00020551 1.50297665f-5 0.0002015446 4.2955442f-5; -6.4730084f-5 7.305937f-6 8.0600716f-5 -3.5063014f-5 8.016925f-5 1.5708165f-5 -5.731141f-5 -4.650733f-6 4.7506215f-5 -6.236255f-6 0.00028957913 -8.2624254f-5 0.00012305794 -4.3587846f-5 5.5094744f-5 8.7724846f-5 -0.00014173985 0.0001800067 -2.5375557f-5 -1.4361358f-5 -7.592933f-5 3.60623f-5 0.00013697374 -7.277859f-5 8.2918144f-5 -0.00014920936 -0.00010799371 -4.3408258f-5 0.00020627811 -0.00016269836 -5.521273f-5 9.739278f-5; 3.3565255f-5 5.3738986f-5 1.6735534f-5 1.5077321f-5 0.00023735811 -5.0975537f-5 -0.00011933445 0.00015816886 4.487118f-5 1.2497004f-5 -3.0371127f-6 4.5330216f-5 -0.00018522037 -7.8367615f-5 -6.0969825f-5 -1.9538406f-6 7.005313f-5 -6.3171436f-5 9.5614654f-5 1.0909118f-5 0.000113791764 2.0711417f-5 0.0001975491 2.0094554f-5 -7.307977f-5 -0.00016553681 -9.911527f-5 3.1533862f-5 -6.648831f-5 -3.2795237f-5 -0.000110719266 6.872789f-5; -4.0572766f-5 0.00018409215 6.1121784f-5 8.345238f-5 2.4154506f-5 -1.6045979f-5 -0.00014043762 -3.8452283f-5 0.0001508398 -5.896191f-5 -0.00012463437 4.5907687f-5 7.803658f-5 0.00014257719 3.8047652f-5 -1.1101596f-5 5.130622f-5 4.8647933f-5 7.7494056f-5 2.6490068f-5 -0.00015397713 -0.000109747554 -0.000116576 -1.38883315f-5 3.4452725f-5 -3.4706678f-5 -6.622261f-5 5.323961f-6 1.6541218f-5 7.283397f-5 -3.225397f-5 4.105441f-5; 8.957128f-5 -0.00016572197 0.00015862344 0.00019210811 -9.2856906f-5 0.00015172221 -9.6504096f-5 2.9101593f-5 5.7247304f-5 -9.891585f-5 -1.4837833f-5 0.00017964272 -4.1168198f-5 6.744171f-5 
9.698955f-5 -0.00015328772 5.1077404f-5 -2.7475327f-5 -0.00011423463 -0.00012764269 -4.5389017f-5 -0.00019768052 1.7969416f-5 -4.1953903f-5 -5.5381035f-5 0.00015497951 0.00017659587 -0.0001073177 -0.00019066458 3.1408566f-5 0.00025642876 0.00019595228; 3.3648106f-5 -7.18424f-6 0.00016233092 4.940778f-5 -5.2615564f-5 -8.1935825f-5 -2.2463562f-5 -3.9520808f-5 -0.0001514371 6.760235f-5 -0.00015348969 -1.3958484f-5 0.00016209559 -6.720103f-5 -8.8949746f-5 0.00019551501 -0.00023603611 -5.8613405f-5 0.00014952876 -4.2462045f-5 0.00012097737 7.366127f-5 4.498226f-5 -2.803307f-5 -6.0533002f-5 -7.633808f-5 -3.0457019f-5 3.9367344f-5 2.928259f-5 2.9768293f-5 0.00015153794 2.742596f-5; -5.3979806f-5 0.00017630782 7.57485f-5 -9.290421f-5 0.00015476535 -0.00018847974 -2.2945222f-5 1.4023992f-5 -6.9009475f-5 0.000105291634 -4.6983158f-5 5.5510314f-5 7.9221645f-5 8.6868255f-5 -3.1757688f-6 1.31696515f-5 8.329642f-5 -5.6218472f-5 -3.4914465f-5 -2.9808103f-5 8.6298736f-5 5.9133683f-5 4.0445284f-5 -5.0724906f-5 -4.328658f-5 -7.2817806f-5 0.00012332301 0.00025270318 1.9555231f-5 2.3896186f-5 0.000120741446 -3.2727905f-5; 1.3778879f-5 0.0001727275 6.726267f-5 -0.00012180348 0.00012602318 5.021141f-5 0.00011486053 -1.6770562f-5 6.217809f-5 4.9483835f-5 6.7858084f-5 -4.3816617f-6 4.4061217f-5 6.405472f-5 -3.1340605f-5 4.574346f-5 4.417262f-5 0.00019557889 -1.7763745f-6 -0.00015498175 8.772362f-5 0.0001605415 -4.592417f-6 -7.961509f-5 -9.884098f-6 -0.00017678777 -1.5010851f-5 6.803548f-5 -0.0001253624 6.7170615f-5 -8.579084f-5 0.0002301844; -0.00014605782 0.00016117327 -3.9911733f-6 -5.651706f-5 -1.3043202f-8 2.4636322f-5 -9.180794f-5 0.0001727085 -0.00010280093 0.00011365875 -5.1469713f-5 0.000113110706 7.047993f-5 0.00016307793 0.00019447276 -9.468344f-5 -7.54893f-5 3.865218f-5 -0.00012947999 -5.147174f-5 1.869637f-5 -3.1421055f-6 3.9913437f-5 -0.00019338199 1.9590247f-5 -4.623005f-5 -1.781631f-5 -9.868818f-5 -6.0971495f-5 -0.00011558041 -6.5484855f-5 -0.0001294781; -2.6511396f-5 
0.00012617247 -3.7048478f-6 -5.7799396f-5 -3.1321952f-6 -7.4464464f-5 8.816393f-5 -0.00010766935 0.00011405761 -9.749204f-5 0.00029193505 2.0313846f-5 9.216313f-5 4.9079994f-5 -0.00021901786 0.00013768725 -6.587706f-5 0.00012264098 -1.5883184f-5 -1.6606564f-5 -1.1046394f-5 -8.92588f-5 -6.418831f-5 -1.0839111f-5 -5.965298f-5 -1.4325025f-5 -6.894914f-5 0.00013277693 -0.00010833104 3.6633846f-5 0.000115710856 2.01233f-6; 6.901256f-5 -3.501283f-5 -3.3093267f-5 -0.00016759482 -0.000116985386 0.00014502452 -3.0716204f-5 -0.00016166325 -9.775906f-5 2.8465218f-5 0.0001094308 5.8769616f-5 7.564489f-5 -0.000116358126 -5.9451744f-5 5.0873143f-5 5.8516995f-5 -4.83106f-5 4.655924f-5 7.6903445f-5 -1.289386f-5 -2.3501008f-5 -9.2675735f-5 6.3503554f-5 6.2028186f-5 -9.5341056f-5 -4.354096f-5 2.2025832f-5 6.191948f-5 -9.159274f-5 -1.3408818f-5 -7.779252f-5; -2.5560243f-5 -0.00012498087 6.324995f-5 -1.9048992f-5 -2.006975f-5 -0.00015745303 6.433916f-5 -0.00016803053 -4.444682f-5 -3.842958f-5 -8.339827f-7 4.6642228f-5 0.00015577994 1.998809f-5 0.00012960617 6.44621f-5 -0.000102829894 -3.446933f-5 -0.00013929082 -9.280317f-5 -9.577131f-5 0.00016670927 -0.00017418277 9.281878f-5 -3.4790475f-5 0.000105538755 1.3261424f-5 -0.00010082788 8.926289f-5 -0.00015622628 -0.00021391474 -0.00010829681; 1.28782685f-5 0.00010857978 9.054971f-5 -2.5754765f-5 3.262942f-5 0.00021267931 -5.4572765f-6 9.156239f-5 0.00020975387 0.00015381096 1.3368687f-5 0.00018093974 3.7886835f-5 8.506997f-5 -4.255179f-5 -6.47533f-7 2.2175411f-6 7.2523806f-5 6.4504435f-5 7.280017f-5 -3.129394f-5 -3.7928953f-6 -4.597413f-6 -7.8519966f-5 4.493161f-5 6.490522f-6 -0.00011298398 0.00010392909 0.000108571105 -0.00017588069 0.00025359765 -0.00011499763; 2.525802f-5 3.0733902f-6 -3.4650446f-5 -5.933747f-5 5.0438957f-6 0.00029479753 7.272233f-5 -3.105818f-5 0.0001027109 -4.534297f-5 8.7309636f-5 -7.195959f-5 -9.174877f-5 0.00010211438 -0.00020541098 2.1261874f-6 -1.8039798f-5 5.370415f-5 -2.2488199f-5 -0.00015086451 
-0.00014079383 0.000107845284 3.8461826f-6 -0.00015340232 1.9428721f-6 0.00021009566 8.320563f-5 0.00012905132 0.00011852019 9.7568896f-5 1.6800703f-5 0.00015577555; 8.1777245f-5 7.962621f-5 1.972223f-5 -8.370306f-5 4.487113f-5 0.00012044325 -0.00019229 -0.00015570421 0.000112018344 -5.3445932f-5 7.75401f-5 -3.0267924f-5 5.516633f-5 5.782704f-5 1.5480045f-5 -5.515792f-5 0.00011929896 0.00013956809 -9.3918396f-5 -2.7678821f-5 -2.3243487f-5 -5.221219f-5 9.694414f-5 -9.6637064f-5 -3.8126593f-5 0.00013917669 -0.00015574449 7.725843f-5 0.00011410993 3.2773525f-5 -9.828506f-5 6.8297755f-5; -2.6732134f-5 -5.477238f-5 4.679015f-5 3.0445713f-6 0.00014159812 -3.5901656f-5 5.2232965f-5 -0.00018263404 0.0001273993 -0.00016346277 9.146056f-5 -0.00023979123 -8.869842f-5 0.00013684829 -8.766403f-5 0.000111703936 -3.579819f-5 -8.1293416f-5 0.00020759284 -0.00014675574 2.4612174f-5 0.00012480737 -8.838496f-5 -4.7609592f-5 1.2618774f-6 -3.4996247f-5 -8.246872f-5 -3.849131f-5 -0.00021376068 -6.73016f-5 -3.484986f-5 -0.00012138895; 0.00014418802 -4.7342728f-5 -6.6959845f-5 -9.735f-5 2.4043704f-5 -0.00023295716 2.0859918f-5 -0.000117808784 0.00023542655 0.00013329109 7.271636f-5 7.488865f-5 -0.00015709517 5.2106734f-6 0.00012833376 3.9335882f-5 -0.00012806366 -0.00018140062 -3.2856035f-5 -7.95967f-5 -2.204764f-5 -1.09725015f-5 0.0003112077 -0.00011755654 1.1712763f-5 2.2837767f-5 0.0001942286 -8.7745364f-5 9.031872f-5 0.00014290574 -7.7450815f-5 1.5633207f-6; -0.00014968443 -1.7143306f-5 8.527949f-5 8.793071f-5 -0.00017844385 0.00021326116 -0.00019502021 -7.1055554f-5 -0.0001177322 -4.4215594f-5 1.4556766f-5 0.0001210549 9.858444f-5 -4.352874f-5 -0.00011841135 -4.169295f-6 0.00018187272 -0.00023867573 -5.913205f-5 -4.2589483f-5 0.0002002157 3.59499f-5 -1.5643189f-5 -1.4911418f-5 3.6662892f-5 -0.00010619918 -0.00021418581 -7.042595f-5 0.00013958821 -7.822104f-5 -7.9300095f-5 -0.00021558776; 2.243932f-5 0.00011392136 -4.8040343f-5 0.00021748875 3.882067f-6 -4.7365585f-5 0.00013206819 
8.591736f-7 -7.431114f-5 -7.408343f-5 8.2682595f-7 -0.00013405975 -1.9140612f-6 0.00020236819 -1.166649f-5 -0.00014871884 -1.6820159f-5 -3.0894756f-5 0.00015423556 -0.0001895839 -1.1463784f-5 -3.5356235f-5 -3.73416f-5 -3.1969657f-6 8.948299f-6 -7.719844f-5 -1.3581514f-5 3.977815f-5 -5.948067f-5 4.705181f-5 -0.00022608115 -0.00017676818; -0.00015645521 -2.7478814f-5 -8.746608f-5 6.810443f-5 9.933508f-5 5.03144f-7 0.00015381901 4.695597f-5 -0.00024910082 6.152454f-5 -4.0812516f-5 -5.8148595f-5 -8.651044f-5 2.4505913f-5 0.000106243264 -1.5069206f-5 4.7592774f-5 -0.00021558648 -6.9693386f-5 7.848854f-5 5.2557658f-5 -3.345853f-5 -3.1961907f-5 -6.1848245f-5 -0.0001394649 -0.00014470251 -1.4073272f-5 -0.000107062304 -9.967692f-5 3.2313736f-5 4.3383785f-5 6.194195f-5; -4.860006f-5 -9.734154f-5 -0.0001940118 -9.164045f-6 5.0527742f-5 0.00013105717 -0.00016451394 -2.677188f-5 7.702753f-5 -4.9771796f-5 -4.7708596f-5 5.4493055f-5 6.806011f-5 5.2728625f-5 -5.725892f-7 -0.00016872004 -0.00020854267 -0.0001670507 -0.00016144563 1.8450044f-5 5.173619f-5 -0.0001055296 3.3585042f-5 1.6963559f-5 -1.7264583f-5 -5.2129057f-5 0.00014631913 1.3275742f-5 0.00013994679 4.9394883f-5 6.314905f-5 -2.1886397f-5; 2.7891672f-5 -8.055247f-5 6.7934365f-5 -0.00017445562 2.5493457f-5 0.00016704487 -0.00013655654 -8.132735f-5 -4.1070516f-7 -2.6571985f-5 -9.321977f-5 -0.000166451 0.00011331619 -2.9578745f-5 0.000100322985 -0.00020685772 6.183576f-5 -5.4126554f-5 4.8903334f-5 0.000167261 -9.2584174f-5 7.717036f-5 -8.7625034f-5 4.7845402f-5 4.6497888f-5 0.00020470102 5.5838485f-5 -3.9365066f-5 -9.254231f-5 -0.00010720353 -0.00017206385 0.00021492367; -3.8361126f-5 -6.599532f-5 -4.598129f-5 -9.98733f-5 0.00010175168 -2.7182547f-5 6.4785156f-5 8.1895916f-5 0.00012002821 1.0533086f-5 -4.255313f-5 2.7097924f-5 -9.365847f-5 -4.970598f-5 -1.5476142f-5 -1.6658176f-5 9.683685f-5 9.467716f-5 0.00017907679 -2.3655657f-5 0.00019751227 -0.00020019055 3.0232344f-5 -7.940426f-6 -5.2111784f-5 5.115775f-5 0.00011208519 
0.000138763 8.328801f-5 1.0775888f-5 9.959968f-5 0.00013685478; -0.00016073188 -9.007675f-5 -1.109181f-5 6.628427f-5 -7.046842f-5 8.570501f-5 -0.00021498844 3.2354223f-5 -0.00011664107 -4.8339596f-5 3.8429793f-5 0.0001281177 2.5189971f-5 -0.0001483478 -0.00013285222 0.00016146652 -1.4767427f-5 3.3958553f-5 -1.3049998f-5 3.7359485f-5 -0.00017603431 1.7387236f-5 3.976641f-5 7.953151f-5 0.00012718313 -5.6128185f-5 1.8825574f-5 -0.00019502253 4.403703f-5 0.00020666957 -0.00018628396 -9.9553276f-5; -2.8659113f-5 0.00010881056 1.1462732f-5 1.1900602f-5 8.705696f-5 -0.00011612871 -3.434809f-5 7.766278f-5 7.107216f-5 -8.977764f-5 -0.00011490829 2.9903958f-5 -0.00014071558 0.000121870275 -0.00016721216 1.5103795f-5 -3.11765f-5 -0.00011506122 -1.7766975f-5 -0.00016404454 -0.00012861619 -3.0342666f-5 5.956928f-5 5.621412f-6 -6.0881703f-5 0.00019867577 -4.3616405f-5 -1.0013361f-5 5.4025735f-5 -6.822442f-5 -0.00013803568 -2.7804821f-5; 3.450164f-5 8.0511614f-5 3.6812744f-5 0.00024383662 -4.0454303f-5 -0.00018649589 -0.00021689838 -2.7792708f-5 -2.583708f-5 -2.1282867f-5 8.5144624f-5 -2.38362f-5 -2.5828955f-5 4.1841813f-5 -0.00014064375 -9.171275f-6 5.0977313f-5 7.921551f-5 0.00011400929 3.515925f-5 -8.598442f-5 -4.3608317f-5 5.267242f-5 8.0801665f-6 -0.0001331453 0.00012306812 0.00018642691 -0.00015589924 -5.6207402f-5 -9.518587f-5 -0.00016970842 -0.00013047757; 5.6434847f-6 -0.00024303925 -2.923487f-5 -0.00018058368 -1.1305992f-6 2.8617897f-5 1.9852941f-5 -5.1321445f-6 2.7097167f-6 -7.389978f-5 4.814572f-5 -7.034472f-5 5.532349f-5 0.0001438425 -2.6313164f-5 -7.160859f-5 -5.9899387f-5 -0.00019316349 2.1280957f-5 -6.7828005f-5 4.1597275f-5 7.551346f-5 -5.1802956f-5 -5.104013f-5 -6.37861f-5 2.6624035f-5 -5.15358f-5 -0.00012811247 7.717415f-5 -4.256658f-5 -4.0507248f-7 -2.9907425f-5; -3.245136f-5 8.120167f-5 7.112196f-5 -2.3054603f-5 -3.7621165f-5 7.061863f-5 0.0001413737 9.2056136f-5 1.5883063f-5 -2.0353624f-5 8.74175f-5 -0.000105295854 -1.4651918f-5 -0.00016896734 5.8400776f-5 
7.858061f-5 2.0541964f-5 2.3797515f-5 0.0002459378 0.00014217362 -0.00017602411 -0.00010113374 -0.0001178778 -0.00011244504 0.00013541324 -0.00014108975 0.00012702649 -7.4670235f-5 5.2182248f-5 -8.029187f-6 -4.353831f-5 3.1841257f-5; 0.00020087116 -0.00013456323 7.4939686f-5 -3.533258f-5 -1.8958384f-5 -0.00018102289 -3.6836136f-5 0.000120879195 0.00016211529 -1.3993787f-5 -3.5013136f-6 3.8572773f-5 9.580833f-5 -4.6826593f-5 0.00012421807 7.328317f-5 -2.7532426f-5 -6.027037f-6 3.8032053f-5 8.195726f-5 6.058548f-5 -0.000116197305 -9.549745f-5 0.0002074817 -5.6234385f-5 -0.00011703478 -8.85015f-5 -0.00014934754 6.816251f-5 -0.00016949202 -0.0001486858 0.00010507445], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-9.58949f-5 -0.00015206226 4.8822036f-5 -0.00013568829 2.0706379f-5 -0.0001389312 9.831056f-5 4.230304f-5 4.2833453f-5 -0.00017125398 0.00010365501 8.6296306f-5 0.0001220919 -9.0496535f-5 -8.2081206f-5 -0.00018763563 9.449259f-5 8.184782f-5 -9.487948f-5 0.00015968592 -0.00013838928 0.00012881361 3.3097465f-5 -7.153539f-5 5.1322345f-6 -5.7302164f-5 9.816883f-5 -0.00013536151 -7.476663f-5 -0.00010774492 5.6236484f-5 3.9106053f-5; -0.00016205004 0.00014085506 -2.8402883f-5 6.978863f-5 -2.0241057f-6 -0.00016049512 1.5812699f-5 -6.512637f-8 -1.1496976f-5 -1.903496f-6 -5.94116f-5 3.8122867f-5 -3.826372f-6 4.6094533f-6 0.0001072184 -8.0179944f-7 4.2702697f-5 9.67502f-5 -4.170707f-5 -6.9280875f-5 -3.0371531f-5 5.0547176f-5 0.000118900796 -6.3581596f-5 6.5623375f-5 0.00013374613 -9.641066f-5 2.5718273f-5 0.00013941042 0.00021900615 3.416923f-5 -0.00020621157], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64
# Lux initializes the parameters in Float32 (see the printed values above). This problem
# needs Float64, so convert them with `f64` and wrap them in a `ComponentArray`; this is
# the object handed to the optimizer as the initial guess.
const params = ComponentArray(f64(ps))
# Bind the network to its (empty) state `st`. The `nothing` argument is presumably the
# stored-parameter slot, left empty because `ODE_model` passes the parameters explicitly
# on every call.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# REPL display of `nn_model`:
# StatefulLuxLayer{Val{true}()}(
#     Chain(
#         layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#         layer_2 = Dense(1 => 32, cos),  # 64 parameters
#         layer_3 = Dense(32 => 32, cos), # 1_056 parameters
#         layer_4 = Dense(32 => 2),       # 66 parameters
#     ),
# )         # Total: 1_186 parameters,
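# plus 0 states.
Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

with

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_2(\chi)\right),$$

where $p$, $M$, and $e$ are constants and $\mathcal{N}_1$, $\mathcal{N}_2$ are the two outputs of the neural network.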
"""
    ODE_model(u, nn_params, t)

Right-hand side of the Newtonian orbit ODE with a neural-network correction.

With state `u = (χ, ϕ)` and `(p, M, e)` read from the global `ode_model_params`, both
derivatives share the Newtonian rate `(1 + e cos χ)^2 / (M p^(3/2))`. That rate is
multiplied by `1 + N₁` for `χ̇` and by `1 + N₂` for `ϕ̇`, where `(N₁, N₂)` is the output
of `nn_model` evaluated at `χ` with parameters `nn_params`. `t` is unused. Returns
`[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; however, in general, we should use `st` to store the state of the neural network.
    correction = 1 .+ nn_model([first(u)], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Same initial state, time span, and fixed-step RK4 settings as the reference simulation,
# but with the neural-ODE right-hand side and the untrained Float64 network parameters.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only h₊ (the first element of the returned (h₊, hₓ) tuple), as for the reference.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
# Compare the reference waveform with the waveform produced by the untrained neural ODE;
# the `begin` block evaluates to `fig`.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
# Reference waveform: line plus scatter markers.
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
# Untrained neural-ODE waveform, drawn the same way.
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
# Each legend entry groups a line plot with its scatter markers.
axislegend(
ax,
[[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"];
position=:lb,
)
fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Return the mean-squared error between the h₊ waveform predicted by the neural ODE with
network parameters `θ` and the reference `waveform`.
"""
function loss(θ)
    # Integrate with the same fixed-step RK4 settings used for the reference solution.
    prediction = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h_plus = first(
        compute_waveform(dt_data, Array(prediction), mass_ratio, ode_model_params)
    )
    return mseloss(h_plus, waveform)
end
# Warmup the loss function
loss(params) # 0.0007053861206901361
# Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(state, l)

Optimization callback. Appends the current loss `l` to `losses` and prints the iteration
number `state.iter` (the first argument is the optimizer's state object) together with
the loss. Always returns `false`; Optimization.jl treats a `true` return as a request to
halt.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.
# Optimize the network parameters with BFGS (backtracking line search, initial step
# norm 0.01), using Zygote to differentiate `loss`. The objective ignores its second
# (hyper-parameter) argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback, maxiters=1000,
)
# retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00011062598787237051; -8.573309605692438e-5; 8.071400225159011e-5; 8.097976387952155e-5; -7.672316860398818e-5; 0.00013390113599591051; 6.6576203607765324e-6; 4.0586113755094395e-5; -1.0330888471779823e-5; -9.85609767665884e-6; -0.00014221298624759812; -0.0001684903108978923; 6.60181322018658e-5; 0.00014108652248982897; 0.00012349701137281317; 3.801641923926065e-6; 1.1083841854992381e-5; 1.694440288696859e-5; -8.012122270880183e-5; 0.0002099185076075341; 5.8674127103543805e-6; 1.183069616673551e-5; 2.250284069303035e-6; -0.0001119848020609021; 6.614451194762337e-6; -0.00014941292465658865; 7.001146877878753e-5; -7.381001341832223e-5; 0.0001555086200822363; 3.1217554351302014e-5; -2.3942982806983757e-5; 0.0001515585609009057;;], bias = [-1.2354196993157884e-16, -1.7755103554174617e-16, 3.906545472014906e-17, -8.989092480613113e-17, -9.840159197246915e-17, 7.958305409368797e-17, 8.142860847552779e-18, 3.546864545679575e-17, -2.1144005673140816e-17, -1.8025041942808042e-17, -1.6313487915646445e-17, -1.7848373273474803e-16, -3.971793926959653e-18, -4.556321370619653e-17, 8.994220616123972e-18, 2.5711165615825954e-18, 5.576912253983404e-18, 5.747843821333374e-17, -4.2314664589985386e-17, 1.3625572039886215e-16, 4.719600394911332e-18, -1.0716785075919338e-17, 2.243289125153811e-18, 1.5831826221344067e-16, 1.918316286151135e-18, -6.146864130915036e-17, 4.7274673781651933e-17, -2.5534610436659844e-16, 1.8259115059281116e-16, 1.3831930601026875e-17, -8.183296991278229e-17, 4.613001970148521e-16]), layer_3 = (weight = [-9.29512154318548e-5 6.120999188784664e-5 8.502481233950831e-5 -5.8428957511967606e-5 7.658049814164336e-6 0.0001533908160279367 -9.354585034383534e-5 -9.63682594669074e-6 0.00019500815853167259 5.144837797775793e-5 0.00011534956303582153 -0.00020734532716526767 -0.00015997833377990808 -8.649429053842473e-5 0.00016354684504837197 -7.425443453325034e-5 5.609550664008524e-5 
-0.000104692377302163 3.3292946515540925e-5 -8.379538320080009e-5 -6.668751865003331e-5 9.46758244178996e-5 -5.628188790364265e-5 0.00013794379794195613 8.112318920826466e-5 -0.00019931229160941515 -3.301992098739117e-7 -0.00010414608385257328 -0.00010282110286719417 0.00011126869668874846 6.632112050974411e-6 -0.0001127947891138616; 3.5010714894407586e-5 -4.236314623708564e-5 0.0002346748698284035 0.00011905298390333779 -8.185438116638712e-5 3.5544399109452766e-5 -9.57994014941998e-5 -1.4153079081128738e-6 7.703499079356867e-5 -6.025923346610776e-5 7.161288169586925e-5 3.061128266525989e-6 2.8462327575787492e-5 -0.0001492138798101631 6.905685233791968e-5 -0.00011780538977914528 -6.23124032034449e-5 -5.022569519210635e-5 0.0001254644195097615 -6.973119514153363e-5 -6.261416626953558e-5 5.9997513856747776e-5 -9.212714565078766e-5 -0.00012743254399251634 -2.7439058500798802e-5 4.812973577205073e-5 2.6691046626318067e-5 -5.3408839336818026e-5 -0.00014211870063885986 -3.7503261298164715e-5 4.51051747615075e-5 -6.675557585990048e-5; 9.38820221838837e-5 -8.553052700308233e-5 1.4216360823685491e-5 -0.00013916194527127895 -6.082389970260824e-5 -0.00012739932660196468 -9.090230102601899e-5 0.00010047861635966798 0.00019332850845812382 -9.198575659838942e-5 0.00011467469360496472 0.00016560853603579128 8.956948207041954e-5 0.00010089355694645622 -0.0001440874775377723 -4.0838193826586696e-5 -1.8081250318302475e-5 4.826763897925064e-6 3.766944340013773e-5 5.51314215320615e-5 -0.00015045881534179182 2.0699691076786565e-5 6.462343210869028e-5 0.0001079080239199399 6.0072906335606504e-5 -3.144891249547985e-5 -9.836643673197567e-5 -0.0001472094745347014 -0.00023097037134874794 5.77687269745169e-5 -0.00013819344255325798 -8.962664375730099e-5; -1.16603459279766e-5 -4.165030525809715e-5 -0.00014865965106918987 -8.193780279879971e-5 5.9780265535940316e-5 -0.00012721311116520556 -3.6871812819020315e-5 -7.976652607948748e-5 -6.210223267016141e-5 9.568880438410538e-5 
-0.00010957407260476057 0.00012060401156148954 0.0001576355799154992 -7.058105159713314e-5 8.443639041460962e-5 5.205680378667877e-5 -0.00013598471462655023 -5.997233921964987e-6 -3.632374039779924e-5 0.00019400801926920292 0.0002090076532734265 -5.837549445613428e-5 -0.00011205177627897956 -0.00016132885405722934 1.790771881888485e-5 0.0002747016342876057 -5.6898144004386374e-5 9.615303230776004e-5 0.000205511788470441 1.5031550054396932e-5 0.00020154637701880952 4.295722562191472e-5; -6.472864127662212e-5 7.307379534008692e-6 8.06021582922907e-5 -3.5061571665014224e-5 8.01706940057699e-5 1.570960763839094e-5 -5.730996575383081e-5 -4.649290289109697e-6 4.7507657897057136e-5 -6.234812402709721e-6 0.0002895805685204398 -8.262281129081736e-5 0.0001230593791353535 -4.3586403747932764e-5 5.5096187012360014e-5 8.772628849894197e-5 -0.00014173840257515746 0.00018000813931050156 -2.5374114314931468e-5 -1.4359915197750423e-5 -7.592788982381657e-5 3.6063744361143677e-5 0.00013697518649334478 -7.277714551836443e-5 8.291958717219919e-5 -0.00014920791521380372 -0.00010799226377801807 -4.3406814924113946e-5 0.00020627955588334434 -0.00016269691489813505 -5.521128853959599e-5 9.739422445473511e-5; 3.356575492860109e-5 5.373948610129204e-5 1.673603386650711e-5 1.507782129836186e-5 0.0002373586083550232 -5.0975037230204656e-5 -0.00011933395049793156 0.00015816935578102706 4.4871679928363515e-5 1.2497503884404649e-5 -3.0366126438138437e-6 4.533071645625896e-5 -0.00018521987127398655 -7.836711498468947e-5 -6.0969325350855854e-5 -1.9533405619962167e-6 7.005363098785768e-5 -6.317093553745633e-5 9.561515420416757e-5 1.0909617994066882e-5 0.00011379226438209904 2.0711917240149552e-5 0.00019754960325288145 2.009505427765946e-5 -7.307926913283276e-5 -0.0001655363104518304 -9.91147700560052e-5 3.153436213240428e-5 -6.648781178312278e-5 -3.279473654485765e-5 -0.0001107187655733664 6.872838806800661e-5; -4.057190930696361e-5 0.00018409300351406426 6.112264050527477e-5 8.345323836321942e-5 
2.415536251275818e-5 -1.6045121850840168e-5 -0.0001404367667820007 -3.845142604775592e-5 0.0001508406548563688 -5.896105180370445e-5 -0.00012463351781123695 4.590854401241385e-5 7.80374375207635e-5 0.00014257804828586622 3.804850890075547e-5 -1.1100739397148664e-5 5.13070753694809e-5 4.864878970322372e-5 7.749491301721897e-5 2.6490924898017083e-5 -0.00015397627616409962 -0.00010974669752984132 -0.00011657514009500675 -1.3887474819130058e-5 3.445358193487108e-5 -3.4705821275640906e-5 -6.622175173477224e-5 5.324817680886782e-6 1.6542075132671526e-5 7.283482951236112e-5 -3.225311415378771e-5 4.105526564300539e-5; 8.957254836408624e-5 -0.0001657207008467202 0.00015862471123531896 0.00019210938121963725 -9.285563787139491e-5 0.00015172347999017655 -9.65028280408238e-5 2.910286134997086e-5 5.724857188019668e-5 -9.891458243523959e-5 -1.4836564827567293e-5 0.00017964399235628092 -4.1166929390015426e-5 6.744297793906069e-5 9.699081643324424e-5 -0.00015328645308195838 5.107867260071215e-5 -2.747405919231249e-5 -0.0001142333617682711 -0.00012764142370012764 -4.538774878178954e-5 -0.00019767925327337577 1.7970684399048372e-5 -4.195263458692279e-5 -5.537976680366554e-5 0.00015498078069489242 0.00017659714144960546 -0.00010731642900660155 -0.00019066330928050267 3.140983467878505e-5 0.0002564300252341547 0.00019595354611359297; 3.364858201739628e-5 -7.183764240560363e-6 0.0001623313975745101 4.940825527801515e-5 -5.261508778284887e-5 -8.193534887477593e-5 -2.2463086510193813e-5 -3.95203322238527e-5 -0.00015143662022973906 6.760282411624184e-5 -0.000153489211528449 -1.3958007742713648e-5 0.00016209606400143726 -6.720055245177689e-5 -8.894927015144176e-5 0.00019551548319787223 -0.0002360356346013726 -5.861292885465114e-5 0.00014952923373703252 -4.2461568967795525e-5 0.00012097784491407192 7.366174685221603e-5 4.498273409472285e-5 -2.8032593776241285e-5 -6.053252656210708e-5 -7.63376054362198e-5 -3.045654267561725e-5 3.9367819396195756e-5 2.928306565654531e-5 2.9768768569104884e-5 
0.00015153841667261405 2.7426435905272068e-5; -5.397700079076255e-5 0.0001763106229173783 7.575130802987401e-5 -9.290140224212525e-5 0.0001547681567882033 -0.00018847693354648393 -2.2942416764549793e-5 1.4026796632690875e-5 -6.900666994798085e-5 0.00010529443897948513 -4.698035263747968e-5 5.5513119148659846e-5 7.922445001721838e-5 8.687105946416935e-5 -3.1729639448902143e-6 1.3172456418119054e-5 8.329922638776623e-5 -5.6215667429483584e-5 -3.491165989718132e-5 -2.980529825613278e-5 8.630154115789488e-5 5.9136487832888076e-5 4.044808883745147e-5 -5.072210115205787e-5 -4.328377690572372e-5 -7.281500075071374e-5 0.00012332581900235669 0.00025270598052710654 1.9558035999644046e-5 2.3898990546598305e-5 0.00012074425103318748 -3.2725100028918446e-5; 1.378234112049269e-5 0.00017273095574634045 6.726613202769297e-5 -0.00012180001947857458 0.00012602664202855958 5.0214870789159296e-5 0.0001148639915481653 -1.676710024187744e-5 6.218155460371403e-5 4.948729685563364e-5 6.786154546711919e-5 -4.378199882978826e-6 4.406467844088789e-5 6.405818414345743e-5 -3.1337143518655166e-5 4.5746923497037616e-5 4.417608062792336e-5 0.00019558235209840094 -1.7729126330194618e-6 -0.00015497829161184373 8.772708526518903e-5 0.0001605449636553856 -4.588955271199678e-6 -7.961163071744826e-5 -9.880636462210587e-6 -0.00017678430747769682 -1.5007388974512656e-5 6.803894058972702e-5 -0.0001253589450062078 6.717407661196447e-5 -8.578737622886247e-5 0.0002301878688012438; -0.0001460593563535404 0.00016117173740331172 -3.99270916768653e-6 -5.651859524386887e-5 -1.4579025940534865e-8 2.463478571717773e-5 -9.180947441218693e-5 0.00017270696784525916 -0.00010280246411294015 0.00011365721670135585 -5.147124889498783e-5 0.00011310916974599672 7.04783954859519e-5 0.00016307639393181318 0.00019447121971443774 -9.468497651705792e-5 -7.549083298381688e-5 3.865064294932147e-5 -0.0001294815234870832 -5.147327524917157e-5 1.8694833983133588e-5 -3.1436412939025526e-6 3.991190109790958e-5 -0.00019338352345239028 
1.958871084188767e-5 -4.623158443445441e-5 -1.7817845815625383e-5 -9.868971451960511e-5 -6.097303108271221e-5 -0.00011558194879053388 -6.548639126431443e-5 -0.0001294796317380953; -2.651057359643754e-5 0.00012617329054747365 -3.7040255649961324e-6 -5.779857376937147e-5 -3.1313729125873888e-6 -7.446364172964979e-5 8.816475026079303e-5 -0.00010766852931582564 0.00011405843549410661 -9.749121811936761e-5 0.00029193587403979394 2.031466835587782e-5 9.216395125959736e-5 4.9080816162771296e-5 -0.00021901704160727178 0.00013768807554852998 -6.587623641125078e-5 0.00012264180265602497 -1.5882362211680437e-5 -1.660574155567132e-5 -1.1045572071717226e-5 -8.925797731380133e-5 -6.41874866489988e-5 -1.0838289134239463e-5 -5.9652156713392227e-5 -1.432420315568144e-5 -6.894832033894304e-5 0.00013277775003344458 -0.00010833021945318063 3.663466856933955e-5 0.00011571167858320123 2.013152263045774e-6; 6.901112624265795e-5 -3.501426203936956e-5 -3.309470071170204e-5 -0.00016759624873000854 -0.00011698681942132987 0.00014502308454186535 -3.071763808319371e-5 -0.0001616646843561153 -9.776049432756402e-5 2.8463784246663322e-5 0.00010942936910371278 5.8768182270549115e-5 7.564345992224465e-5 -0.00011635955911540461 -5.9453177755384415e-5 5.087170916314209e-5 5.851556102217153e-5 -4.831203500569813e-5 4.655780480763215e-5 7.690201141461889e-5 -1.2895293594089142e-5 -2.350244199104629e-5 -9.267716836828766e-5 6.350212038316386e-5 6.202675262985726e-5 -9.534248988554734e-5 -4.354239376906106e-5 2.2024398049067158e-5 6.191804982311258e-5 -9.159417118124646e-5 -1.3410251313222125e-5 -7.77939569618635e-5; -2.5563326089840093e-5 -0.0001249839536107303 6.324686702972479e-5 -1.905207525889185e-5 -2.0072833904501553e-5 -0.00015745611588536347 6.433607788456153e-5 -0.0001680336120265751 -4.444990398080531e-5 -3.843266520600279e-5 -8.370661525077222e-7 4.663914452940758e-5 0.00015577685709975642 1.9985007129070882e-5 0.0001296030838780958 6.445901974036834e-5 -0.00010283297731841427 
-3.447241240830586e-5 -0.00013929389959268086 -9.280625661665332e-5 -9.577439621391227e-5 0.00016670618944733895 -0.00017418585619491045 9.281569666162937e-5 -3.479355862548093e-5 0.00010553567128190696 1.3258340658997561e-5 -0.00010083096120107828 8.925980520835257e-5 -0.0001562293603277987 -0.0002139178234578336 -0.00010829989176308747; 1.2883202342178623e-5 0.0001085847113956561 9.055464104515206e-5 -2.5749831567530754e-5 3.2634354991594246e-5 0.0002126842453237317 -5.452342722275768e-6 9.156732337781713e-5 0.0002097588010458256 0.00015381589835509465 1.3373620984750486e-5 0.00018094467011741278 3.7891769150728866e-5 8.507490636311818e-5 -4.254685629511089e-5 -6.4259919259894e-7 2.2224749302384264e-6 7.252873981836956e-5 6.450936851090471e-5 7.280510251635873e-5 -3.128900718526169e-5 -3.7879615063159017e-6 -4.5924791463785255e-6 -7.851503252985289e-5 4.4936544125664974e-5 6.495455889488373e-6 -0.00011297904640902802 0.00010393402115207973 0.00010857603845415295 -0.00017587575641461053 0.0002536025865795063 -0.00011499269678261277; 2.5260484963187994e-5 3.075855775287239e-6 -3.4647980155304124e-5 -5.933500602596634e-5 5.046361311807745e-6 0.00029479999551793387 7.272479610064391e-5 -3.1055716189898456e-5 0.00010271336845134703 -4.534050554839576e-5 8.731210158683009e-5 -7.195712629229829e-5 -9.174630491516572e-5 0.00010211684906633244 -0.00020540851352026637 2.128652968437288e-6 -1.803733271999899e-5 5.370661432634412e-5 -2.2485733152015113e-5 -0.00015086204171870285 -0.0001407913596810461 0.00010784774980520214 3.848648193795004e-6 -0.00015339985207875933 1.945337717458343e-6 0.00021009812820946277 8.320809769456212e-5 0.000129053782930252 0.00011852265646691699 9.757136189391154e-5 1.6803168309976363e-5 0.00015577801144925303; 8.17783627048155e-5 7.962732769959053e-5 1.9723348208963278e-5 -8.370194273590637e-5 4.487224641200585e-5 0.00012044436579747121 -0.00019228887986679857 -0.00015570309488390925 0.0001120194617164385 -5.3444814789661344e-5 
7.75412160641505e-5 -3.0266806157286388e-5 5.5167446311707795e-5 5.782815850982365e-5 1.54811629676502e-5 -5.5156803960532985e-5 0.000119300075943495 0.00013956920583444343 -9.391727808685563e-5 -2.7677703934009373e-5 -2.3242369302868023e-5 -5.221107431276996e-5 9.694525642265277e-5 -9.663594692535509e-5 -3.812547520789609e-5 0.0001391778029705381 -0.00015574337458526248 7.725954919299555e-5 0.00011411104487199104 3.277464265414189e-5 -9.82839440490057e-5 6.829887269707208e-5; -2.6735100397831976e-5 -5.4775345922444264e-5 4.6787181638882366e-5 3.0416047073403896e-6 0.0001415951576225946 -3.590462234312236e-5 5.222999811195413e-5 -0.00018263701145514112 0.0001273963300653233 -0.00016346573625668302 9.145759599308206e-5 -0.0002397941987120775 -8.870138640957826e-5 0.00013684532518057775 -8.76669998953474e-5 0.00011170096899405694 -3.58011582258658e-5 -8.12963824111536e-5 0.00020758986851689666 -0.00014675871148678323 2.460920790582539e-5 0.00012480440113250753 -8.83879236036482e-5 -4.761255855298686e-5 1.2589108499124351e-6 -3.499921309154556e-5 -8.247168967244104e-5 -3.849427753919572e-5 -0.00021376364544008756 -6.73045631239683e-5 -3.4852828100326725e-5 -0.00012139191472412101; 0.0001441887741565351 -4.734196974834115e-5 -6.695908652556729e-5 -9.734924080487011e-5 2.404446220669311e-5 -0.00023295639968648498 2.08606759587212e-5 -0.00011780802580698874 0.0002354273035779272 0.00013329184892416491 7.271711527198551e-5 7.488941062535892e-5 -0.00015709441342562054 5.211431725798236e-6 0.0001283345207228896 3.9336640638686025e-5 -0.00012806290598757077 -0.00018139986625112108 -3.2855276204027616e-5 -7.959594060605974e-5 -2.2046881185923466e-5 -1.0971743195222015e-5 0.0003112084539496948 -0.00011755578290842289 1.1713521216626048e-5 2.2838524982830917e-5 0.00019422936182298138 -8.774460568030939e-5 9.03194757300978e-5 0.00014290650304235982 -7.745005697881327e-5 1.56407898418828e-6; -0.0001496875797338952 -1.7146451183244077e-5 8.527634558861669e-5 8.792756630003051e-5 
-0.00017844699546101044 0.00021325800985227166 -0.00019502335858403845 -7.105869958748082e-5 -0.00011773534283845817 -4.421873968665621e-5 1.4553620120699666e-5 0.00012105175624033402 9.858129264914736e-5 -4.3531885649878244e-5 -0.00011841449527406579 -4.172440595610773e-6 0.00018186957236041186 -0.00023867887143849508 -5.913519494340241e-5 -4.2592628643459826e-5 0.00020021255254407274 3.594675358628726e-5 -1.564633426012613e-5 -1.4914563008519137e-5 3.665974650077385e-5 -0.00010620232428759992 -0.00021418895853555402 -7.042909642320643e-5 0.0001395850681992201 -7.822418633720878e-5 -7.930323995522307e-5 -0.00021559090460055163; 2.2437507839782795e-5 0.00011391954355363233 -4.8042156078597446e-5 0.00021748693576969591 3.880254342359214e-6 -4.736739832136663e-5 0.00013206637847192594 8.573607766716564e-7 -7.43129505570791e-5 -7.408524218758563e-5 8.250131196850619e-7 -0.0001340615647057651 -1.9158739922647425e-6 0.00020236637534995792 -1.1668302460604274e-5 -0.00014872065478980536 -1.6821971468608743e-5 -3.089656932556984e-5 0.0001542337474967577 -0.00018958570693196014 -1.1465597009966121e-5 -3.535804829139253e-5 -3.734341337972319e-5 -3.1987785694977507e-6 8.946486348712385e-6 -7.72002542651873e-5 -1.3583327228705785e-5 3.977633857007458e-5 -5.9482482567396786e-5 4.7049996981649306e-5 -0.00022608295806262604 -0.00017676999526141732; -0.00015645805780879516 -2.7481664637912073e-5 -8.746892655838173e-5 6.810158281948532e-5 9.93322266185475e-5 5.0029381553568e-7 0.00015381616156443463 4.69531207090489e-5 -0.0002491036701212871 6.152168696641437e-5 -4.0815366282924294e-5 -5.8151444882244934e-5 -8.651328773342163e-5 2.450306254103262e-5 0.00010624041367917306 -1.5072055994125815e-5 4.7589923433377244e-5 -0.000215589328781818 -6.969623593252051e-5 7.848569128328786e-5 5.255480771663108e-5 -3.3461381299304925e-5 -3.19647568358344e-5 -6.185109470608912e-5 -0.0001394677509051983 -0.00014470536354566286 -1.4076122006397168e-5 -0.00010706515406129873 -9.967977332053831e-5 
3.231088538083776e-5 4.338093474868244e-5 6.193910137880359e-5; -4.860219336369398e-5 -9.73436697120658e-5 -0.00019401392704712948 -9.166177495720495e-6 5.052560960502535e-5 0.00013105503536390214 -0.00016451607438176176 -2.677401311194181e-5 7.702539665951166e-5 -4.977392812980568e-5 -4.771072848488348e-5 5.449092283125828e-5 6.80579753508891e-5 5.272649219587651e-5 -5.747217426658452e-7 -0.00016872217451128414 -0.00020854480020608053 -0.00016705283700394662 -0.00016144775940011048 1.844791155347146e-5 5.17340588532458e-5 -0.00010553173409681059 3.3582909292082735e-5 1.696142613691442e-5 -1.7266715748345013e-5 -5.213118923965469e-5 0.00014631699674391287 1.3273609748644755e-5 0.00013994465485776345 4.9392750280451195e-5 6.31469149640835e-5 -2.1888529535127354e-5; 2.7890870731409754e-5 -8.055327027433946e-5 6.793356322121392e-5 -0.0001744564227210679 2.5492655075122933e-5 0.00016704406760071285 -0.0001365573419407348 -8.132815248429386e-5 -4.115067129270112e-7 -2.6572786120615214e-5 -9.32205742374348e-5 -0.00016645180705564876 0.00011331539160708425 -2.9579546500840918e-5 0.00010032218295053388 -0.00020685851752388876 6.183495741241029e-5 -5.412735601750781e-5 4.890253210576764e-5 0.00016726019264569133 -9.258497568409543e-5 7.716956195325008e-5 -8.762583576937791e-5 4.7844600592709346e-5 4.64970859666017e-5 0.0002047002187082264 5.583768386530232e-5 -3.9365867743999485e-5 -9.254310982397714e-5 -0.0001072043317635879 -0.0001720646553822198 0.00021492286639657814; -3.835797068307557e-5 -6.599216684110343e-5 -4.5978135662086104e-5 -9.987014599021322e-5 0.00010175483222708376 -2.717939131247379e-5 6.47883109341649e-5 8.189907131607422e-5 0.00012003136836581797 1.0536241658047584e-5 -4.254997364472867e-5 2.710107875468499e-5 -9.365531403448004e-5 -4.970282569405859e-5 -1.54729866916895e-5 -1.6655020575780062e-5 9.684000289424253e-5 9.468031677541237e-5 0.00017907994080458055 -2.3652501826262263e-5 0.00019751542751335476 -0.0002001873970438454 3.0235499449264628e-5 
-7.937271024214031e-6 -5.2108628355125564e-5 5.116090434608496e-5 0.00011208834368160369 0.00013876615886699727 8.329116573445753e-5 1.0779043092116086e-5 9.960283679545742e-5 0.00013685793711925304; -0.00016073415577342391 -9.00790194042587e-5 -1.1094081177037015e-5 6.628200094051591e-5 -7.047069198383138e-5 8.570273962174369e-5 -0.00021499070976792445 3.2351951400735784e-5 -0.00011664334420279208 -4.8341867344468664e-5 3.8427521527744774e-5 0.00012811542432825666 2.518769974481549e-5 -0.00014835006846045175 -0.00013285449063335912 0.00016146425218389754 -1.4769698512871563e-5 3.39562818647682e-5 -1.305226965418399e-5 3.7357213610800565e-5 -0.00017603658602970265 1.738496461835151e-5 3.9764140407307174e-5 7.952924238121504e-5 0.00012718085667652468 -5.613045619032198e-5 1.8823303204367815e-5 -0.0001950247981345855 4.403476009099097e-5 0.00020666730301523105 -0.00018628622752080244 -9.955554674309484e-5; -2.866157530245846e-5 0.00010880809412316485 1.1460269601743181e-5 1.189813945944888e-5 8.705449425593899e-5 -0.00011613116959363241 -3.435055197354926e-5 7.76603162053654e-5 7.106969559077589e-5 -8.978010311279032e-5 -0.00011491075139513074 2.9901496131591302e-5 -0.00014071804190774622 0.00012186781230827933 -0.00016721462026491997 1.51013325053634e-5 -3.117896204952664e-5 -0.00011506368474823925 -1.7769437650200076e-5 -0.0001640470030137317 -0.0001286186482643109 -3.03451282119942e-5 5.9566817694357014e-5 5.618949839635903e-6 -6.0884164954987897e-5 0.00019867330837276164 -4.361886687132661e-5 -1.0015823423088007e-5 5.402327286408764e-5 -6.822688311825956e-5 -0.000138038146647378 -2.7807283291095745e-5; 3.450011571157434e-5 8.05100884322117e-5 3.681121814633146e-5 0.00024383509515654657 -4.045582828940498e-5 -0.0001864974161576387 -0.00021689990318510648 -2.7794234134046134e-5 -2.5838605884485485e-5 -2.1284393019880714e-5 8.514309809879507e-5 -2.383772663551801e-5 -2.5830480458817735e-5 4.184028728968555e-5 -0.000140645277147546 -9.172801233501758e-6 
5.0975786857167665e-5 7.921398572260559e-5 0.00011400776782874372 3.515772403671766e-5 -8.598594796741348e-5 -4.360984313781332e-5 5.267089580637437e-5 8.078640698380354e-6 -0.00013314682074952837 0.0001230665969928841 0.00018642538851462584 -0.00015590076292529398 -5.620892772230947e-5 -9.518739317674109e-5 -0.00016970994840040743 -0.00013047909089027658; 5.6402394302743005e-6 -0.00024304249396934975 -2.923811567868998e-5 -0.0001805869252739097 -1.133844455562335e-6 2.8614651307210727e-5 1.9849696252106736e-5 -5.135389676966116e-6 2.706471456816072e-6 -7.390302213296934e-5 4.8142475713608826e-5 -7.034796741477648e-5 5.532024427977202e-5 0.0001438392499376607 -2.6316409423970265e-5 -7.161183763217829e-5 -5.990263176771519e-5 -0.00019316673584670447 2.1277712073261215e-5 -6.783125005580187e-5 4.159403018730976e-5 7.551021538985257e-5 -5.1806200793168495e-5 -5.104337484496898e-5 -6.378934646180546e-5 2.6620789719380848e-5 -5.1539045819464524e-5 -0.00012811571665495705 7.717090178361213e-5 -4.2569823711333685e-5 -4.08317709331035e-7 -2.9910670639728965e-5; -3.245023298218475e-5 8.120280064035757e-5 7.112308913933695e-5 -2.3053475604638258e-5 -3.762003733568222e-5 7.06197602194564e-5 0.0001413748318664625 9.205726338409644e-5 1.588418991746263e-5 -2.0352496438088142e-5 8.741862940227693e-5 -0.00010529472682534001 -1.4650790479162637e-5 -0.00016896620990151449 5.840190307212325e-5 7.858173877325724e-5 2.0543091079209678e-5 2.3798641916121046e-5 0.0002459389138603455 0.00014217475064654979 -0.00017602298656852099 -0.00010113260985958511 -0.00011787666925369879 -0.00011244391543000575 0.00013541436738892344 -0.00014108862568533047 0.00012702761842961133 -7.466910800571867e-5 5.218337492992108e-5 -8.02805924463235e-6 -4.353718439956331e-5 3.18423840435965e-5; 0.00020087118433743449 -0.00013456320915416147 7.493971030648536e-5 -3.533255688658244e-5 -1.8958359883849318e-5 -0.00018102286152120952 -3.683611169968055e-5 0.00012087921901561061 0.00016211531589228404 
-1.3993763000921106e-5 -3.5012891529318047e-6 3.857279704256637e-5 9.580835727696978e-5 -4.682656902864846e-5 0.00012421809775709836 7.328319303646426e-5 -2.7532401087095723e-5 -6.027012783173931e-6 3.803207697651089e-5 8.195728411387362e-5 6.0585504674368786e-5 -0.00011619728060018528 -9.549742857039232e-5 0.00020748172652397038 -5.6234360395914626e-5 -0.00011703475787349319 -8.850147890903316e-5 -0.00014934751846872534 6.81625341386693e-5 -0.0001694919966538276 -0.00014868577012371528 0.0001050744711975726], bias = [-9.660500295423935e-10, -1.3059223912969864e-9, -1.5288330379358577e-9, 1.7835812350454296e-9, 1.4427007215004774e-9, 5.000757288649013e-10, 8.567039220942696e-10, 1.2682502041943338e-9, 4.758351712520487e-10, 2.804879927239147e-9, 3.4618278477791052e-9, -1.5358238694921515e-9, 8.222484294299286e-10, -1.433591284143692e-9, -3.0834420815829706e-9, 4.933807065773912e-9, 2.4655858366801094e-9, 1.1174910765995936e-9, -2.966550880917125e-9, 7.583183844233103e-10, -3.14544293240116e-9, -1.8128338083526715e-9, -2.8501927156923083e-9, -2.1325275232403452e-9, -8.015577088005366e-10, 3.155233645254626e-9, -2.271238962090006e-9, -2.46228904987297e-9, -1.52578242089371e-9, -3.2452253317412057e-9, 1.127329644541343e-9, 2.4432798157661613e-11]), layer_4 = (weight = [-0.0007752528964381777 -0.0008314202336795446 -0.0006305359260482846 -0.000815046221384134 -0.00065865158864451 -0.0008182892056172394 -0.0005810474435219902 -0.0006370549391940377 -0.0006365245584713429 -0.0008506117959505727 -0.0005757027303773267 -0.000593061655499317 -0.0005572661085503353 -0.000769854500289913 -0.00076143898260666 -0.0008669930112997963 -0.0005848652819498644 -0.0005975101655130368 -0.000774237274591768 -0.0005196720821828705 -0.00081774704751006 -0.0005505443289940173 -0.0006462603548291704 -0.0007508932940670396 -0.0006742257670611891 -0.0007366599315207597 -0.0005811890698999547 -0.000814719370489491 -0.0007541245855309971 -0.0007871026645440827 -0.0006231215029914905 
-0.0006402519641720713; 6.008354510099205e-5 0.00036298863850916786 0.00019373069231320612 0.00029192219469220524 0.00022010947126783974 6.163846633573222e-5 0.00023794628626988838 0.00022206845438654831 0.00021063661530346917 0.00022023003120661056 0.00016272190379365317 0.0002602564425485357 0.0002183072161507469 0.00022674302979106283 0.00032935191728797244 0.00022133158822993008 0.0002648362437981267 0.00031888378185888325 0.0001804264521830984 0.00015285271367346208 0.00019176198004120575 0.00027268074457858604 0.00034103432530511893 0.00015855196083732132 0.00028775696327108466 0.0003558796408204822 0.00012572289598720315 0.00024785181567917746 0.0003615439924101246 0.00044113965413482655 0.0002563028118449078 1.592202596475225e-5], bias = [-0.00067935801719514, 0.0002221335932435822]))Visualizing the Results
Let us now plot the loss over time
begin
    # Training curve: one point per optimizer iteration, drawn as a line plus markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)
    fig
end
# Finally let us visualize the results
# Re-integrate with the optimized parameters `res.u`, using the same fixed-step RK4
# settings as in `loss`, and convert the trajectory into a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = let θ_trained = res.u
    sol = solve(prob_nn, RK4(); u0=u0, p=θ_trained, saveat=tsteps, dt=dt, adaptive=false)
    Array(sol)
end
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
begin
    # Final comparison: the reference data, the untrained prediction, and the prediction
    # with the optimized parameters. Every series is drawn as a line plus circle markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
# Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

# Also report GPU toolchains, but only for backends whose package is loaded in this
# session and whose device MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
# Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0This page was generated using Literate.jl.