Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.
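We need a very crude 2-body path. Assume the 1-body motion is the Newtonian 2-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to recover the individual positions: with total mass $M = m_1 + m_2$, $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$.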
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of two bodies
with masses `m₁` and `m₂`, measured from their common centre of mass.

Returns `(r₁, r₂)` with `r₁ = (m₂ / M) .* path` and `r₂ = (-m₁ / M) .* path`, where
`M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution expressed in orbital variables into Cartesian `(x, y)`
coordinates. Each column of `soln` is one time sample, and `soln` must have 2 or 4 rows:

- 2 rows `(χ, ϕ)`: the constants `(p, M, e)` must be passed as `model_params`
  (only `p` and `e` are used here).
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` are taken per sample from `soln`, and
  `model_params` is ignored.

With `r = p / (1 + e cos χ)`, the result is a `2 × N` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    # Tuple membership avoids allocating a throwaway vector on every call.
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # `nothing` (the default) has no `length`. Check it explicitly so callers get
        # this assertion message rather than a MethodError.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` with sample
spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2dt`. The two endpoints
use second-order one-sided stencils, so the result has the same length as `v`. Indexing
`v[3]` means `v` needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    forward = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    backward = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    central = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(forward, central, backward) / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` with sample
spacing `dt`.

Interior points use `(v[i-1] - 2v[i] + v[i+1]) / dt^2`. The endpoints use second-order
one-sided four-point stencils, so the result has the same length as `v`. Indexing `v[4]`
means `v` needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    first_point = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    last_point = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(first_point, interior, last_point) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return the time series of one component of the trace-free mass quadrupole tensor
`mass * (I_ij - δ_ij * tr(I) / 3)`, where `I_ij = x_i * x_j`.

Only the first two rows of `orbit` (the `x` and `y` coordinates, one column per sample)
are used, so the orbit is treated as planar (`z = 0`).

`component` is an index pair `(i, j)` with `i, j ∈ 1:3`. The tensor is symmetric, so
`(i, j)` and `(j, i)` give the same result. Any other pair throws an `ArgumentError`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i, j = component[1], component[2]
    # The tensor is symmetric: normalise to i ≤ j so each case is handled once.
    if i > j
        i, j = j, i
    end
    trace = x .^ 2 .+ y .^ 2
    if i == 1 && j == 1
        tmp = x .^ 2 .- trace ./ 3
    elseif i == 2 && j == 2
        tmp = y .^ 2 .- trace ./ 3
    elseif i == 1 && j == 2
        tmp = x .* y
    elseif i == 3 && j == 3
        # z = 0 for a planar orbit, so only the trace term survives.
        tmp = .-trace ./ 3
    elseif (i == 1 || i == 2) && j == 3
        # x*z and y*z vanish identically for a planar orbit.
        tmp = zero(trace)
    else
        # Previously any unknown pair silently fell through to Ixy.
        msg = "component must be an index pair with entries in 1:3, got $(component)"
        throw(ArgumentError(msg))
    end
    return mass .* tmp
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the `component` entry of the trace-free
quadrupole tensor of `orbit` (built by `orbit2tensor`). The derivative is taken
numerically with `d2_dt2` using sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    quadrupole = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(quadrupole, dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(1,1)`, `(1,2)` and `(2,2)` components, in that order, of
`h_22_quadrupole_components` for `orbit`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` of a single (test) body following `orbit`,
using unit mass.

With `(h11, h12, h22) = h_22_quadrupole(dt, orbit)`, returns
`(√(π/5) * (h11 - h22), -√(π/5) * 2 * h12)`, with constants in the precision of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    scale = sqrt(T(π) / 5)
    plus = scale * (h11 - h22)
    cross = -scale * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the `(h11, h12, h22)` quadrupole contributions (see `h_22_quadrupole`) of two bodies
with orbits `orbit1`, `orbit2` and masses `mass1`, `mass2`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q1 = h_22_quadrupole(dt, orbit1, mass1)
    q2 = h_22_quadrupole(dt, orbit2, mass2)
    return q1[1] + q2[1], q1[2] + q2[2], q1[3] + q2[3]
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` from the orbits of body 1 (`orbit1`, mass
`mass1`) and body 2 (`orbit2`, mass `mass2`), sampled with spacing `dt`.

The masses must sum to unity; otherwise an `AssertionError` is thrown.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # A fixed 1e-12 tolerance rejects valid Float32 masses, e.g. those derived from a
    # Float32 mass ratio, whose sum is only accurate to a few ulps. Allow a few ulps of
    # the masses' float type while keeping the original 1e-12 bound for Float64.
    total_mass = mass1 + mass2
    tol = max(1.0e-12, 4 * eps(float(typeof(total_mass))))
    @assert abs(total_mass - 1.0) < tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` for the ODE solution `soln`, sampled with
spacing `dt`. `soln` is first converted to Cartesian coordinates by
`soln2orbit(soln, model_params)`.

`mass_ratio` must lie in `[0, 1]`:
- For `mass_ratio == 0`, the single-body formula `h_22_strain_one_body` is used.
- Otherwise, the orbit is split with `one2two` into bodies of mass
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using the state $u_1 = \chi$ and $u_2 = \phi$,
where $p$, $M$, and $e$ are constants.
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the orbit ODEs for a point-like particle in a Schwarzschild
background. The state is `u = [χ, ϕ]`, and `p`, `M` and `e` are constant parameters.
`t` is unused. Returns `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    ecosχ = e * cos(χ)
    numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    sqrt_term = sqrt(p - 6 - 2 * ecosχ)
    return [numer * sqrt_term / (M * (p^2) * denom), numer / (M * (p^(3 / 2)) * denom)]
end
# Problem setup for the true (relativistic) model.
# - The tutorial works in Float64 (the network parameters are converted with `f64`), so
#   the time span is Float64 too. A Float32 `tspan` would silently make `tsteps`,
#   `dt_data` and the ODE time variable Float32.
# - Globals read inside `loss` are declared `const` so that function stays type-stable.
const mass_ratio = 0.0 # test particle
const u0 = Float64[π, 0.0] # initial conditions
const datasize = 250
const tspan = (0.0, 6.0e4) # time span for the GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
const dt_data = tsteps[2] - tsteps[1]
const dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    # Plot the reference waveform as a line, with the sampled points overlaid.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")
    l = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s = scatter!(ax, tsteps, waveform; alpha=0.5, markersize=12, marker=:circle)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes 1 input (the state variable $\chi = u_1$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
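We will deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal distribution with a very small standard deviation ($10^{-4}$), and biases start at zero.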
const nn = let
    # Every Dense layer shares the same initialization: tiny truncated-normal weights
    # and zero biases.
    init = (; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init...),
        Dense(32 => 32, cos; init...),
        Dense(32 => 2; init...),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.0002918029; 0.00016396421; -0.00015462728; -0.0001412375; 2.1553558f-5; -8.221679f-5; 1.6669823f-5; -0.00021235233; -0.00021415077; 4.339119f-5; 0.000106239524; -0.00013206032; -5.365259f-5; 8.104361f-5; 9.736526f-5; -6.860325f-5; -3.5452786f-7; -9.2317554f-5; 2.4067065f-5; 0.00011212931; 8.158583f-5; -9.260551f-6; -7.6449774f-5; 2.552661f-5; -8.425936f-5; -3.1369822f-5; -1.913807f-6; 0.00010249789; -9.1437476f-5; 6.349012f-6; -0.00017337427; 0.00025099676;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-3.459632f-5 -9.511012f-6 0.00018080286 -0.00017042752 0.00010780062 5.7241825f-5 -0.00015344778 0.00011164283 -0.00018069279 0.00013770127 3.3959885f-5 -5.8425863f-5 0.0001965221 -7.742922f-5 -0.00019111208 -6.179408f-5 5.990949f-6 -9.063119f-6 5.2834825f-5 -2.2616383f-5 0.00016521718 0.000107877735 -5.210156f-5 -0.00014179201 -1.1983214f-5 0.00013712845 -5.6959496f-5 -4.530252f-5 -4.7227226f-5 2.3852219f-5 -0.00011203043 -7.825932f-5; -9.077223f-6 -0.000110150715 -0.00013953866 6.401513f-5 -0.00012878446 4.2585456f-5 -1.8357257f-5 -0.00024089502 3.691254f-5 0.00011573212 2.440072f-5 8.431817f-5 5.1518327f-5 -0.00014254195 0.00014375376 9.816972f-6 0.00013285833 0.00012450144 0.00012131429 -0.0001203604 0.000111805224 0.00023625417 4.6467943f-5 -1.4697964f-6 3.718658f-5 1.8928038f-5 2.7456033f-6 3.135674f-5 -6.942362f-5 9.974606f-5 -3.952018f-5 6.0235765f-5; 0.00011096574 -0.00018745176 -2.313079f-5 2.3567648f-6 0.00018331567 -2.3246983f-5 -0.00017054955 -7.912776f-5 -1.0885672f-5 -0.00017306971 0.00011542462 0.00013101599 -0.00019418065 -7.959536f-6 -1.3115851f-5 8.8250876f-5 3.3426552f-6 -4.5119745f-5 0.0001506827 -0.0002014937 8.3395316f-5 -4.570943f-5 7.265096f-5 3.2685773f-5 -7.440379f-5 -3.019198f-5 
-0.0001391624 -0.00011491972 -6.0296232f-5 5.797064f-5 -0.0002561306 -7.2498275f-5; -0.00015928454 1.1530831f-5 -0.00012295511 1.4254113f-5 8.271654f-5 8.5859814f-5 -3.9888942f-5 -7.45323f-5 0.0001789882 0.000108929 -0.0002226897 -8.599393f-5 5.8038648f-5 -0.00019867028 6.729028f-7 0.00013247477 0.000119685224 -9.812513f-5 2.5564597f-5 6.847332f-5 2.1398462f-6 -1.8658711f-5 -0.00013953423 5.16256f-6 -0.00029467445 1.9756464f-5 0.00021108448 -9.5832715f-5 -3.6613f-5 -3.237712f-5 0.00018019116 2.0336023f-5; 6.1386534f-5 0.00019318327 0.00010950063 3.7606747f-5 -8.347488f-6 2.3650678f-5 -3.371944f-5 4.7850954f-6 6.118861f-5 5.8611524f-5 8.6605956f-5 -0.00013331867 1.5296564f-5 9.123018f-5 0.0002995452 0.00013557666 7.755482f-5 1.7363423f-6 -8.685375f-5 -7.510023f-5 -8.300563f-5 -0.0001251381 5.9423295f-5 -1.818171f-5 1.9650559f-5 6.6088585f-5 8.331866f-5 -9.3515984f-5 6.621654f-6 -3.986251f-5 0.00011874825 0.000113477385; 4.588589f-6 6.8714926f-6 -4.2249535f-6 -1.4950157f-5 -0.0001450161 0.000112961534 -0.00011599532 7.455297f-5 4.0074796f-5 0.0001223722 -5.6320016f-5 -2.739748f-5 -2.2885295f-5 -1.864261f-5 3.3413264f-5 -2.2927876f-5 -0.00019113542 -0.00013768561 0.00012264214 2.824638f-5 -0.00012692228 2.4045668f-5 -3.9597457f-5 0.0001876386 -0.00013007747 0.00023284399 -4.63177f-5 -4.819755f-5 -7.420977f-6 -4.468547f-5 3.2607786f-5 2.6385447f-5; 0.00019794551 -0.00010020064 0.00015737582 -3.0941053f-5 -0.00010871976 7.8603676f-5 -0.00011579636 -0.00011793036 -6.6104094f-6 1.9156038f-5 -5.902179f-5 -3.5747857f-5 6.5095264f-5 -7.8468984f-5 -1.9715104f-5 1.9111283f-5 -1.1972981f-5 7.214078f-5 -3.9180704f-5 -7.74959f-5 -2.6726653f-5 -2.532606f-5 3.5238427f-6 4.8100213f-5 -0.00014451862 -5.386706f-5 2.3550714f-5 4.6306057f-5 7.870675f-5 -0.00016158225 -6.511595f-5 0.00010404207; -1.4949642f-5 4.1103503f-5 -4.9647548f-5 1.2279523f-5 1.6756549f-5 -9.522368f-5 -9.341785f-5 -9.736942f-5 0.00012049216 5.776135f-5 4.501106f-5 -0.00011462717 -9.314351f-5 0.00015411506 
-9.2134205f-6 -3.9952625f-5 0.00012417417 0.00011803014 1.3827782f-5 1.558044f-5 7.76975f-5 -1.4797531f-5 4.2093165f-5 0.0003342653 -0.00012017209 8.293041f-5 -5.7410074f-5 4.0415784f-5 0.00011312975 7.4652664f-5 -1.8112256f-5 0.00011219188; 1.3738051f-5 6.6912275f-5 2.0624144f-5 9.217131f-5 0.00036655957 1.12467815f-5 0.0001541046 0.00017168254 5.3317242f-5 -9.354058f-5 0.00011692095 -0.000118184806 -2.280535f-5 -6.244324f-6 8.609787f-5 -0.00013240684 -0.00017929001 4.179979f-5 0.00015553973 0.00012416519 0.00016793591 0.00016096492 -1.2762382f-5 5.96124f-5 -0.00014541938 -4.3202155f-5 4.0437993f-5 8.322974f-5 -0.0002236273 8.1855265f-5 -9.050302f-5 3.268254f-5; 9.564198f-5 -2.9971021f-5 -2.2197963f-5 0.000106039064 3.361783f-5 2.8208583f-6 -4.2388936f-5 -9.597139f-5 6.458772f-5 5.835857f-7 -3.195639f-5 6.24139f-6 -2.3944966f-5 -2.0748847f-5 2.1893926f-5 -8.804188f-5 3.667322f-5 2.0819944f-5 3.850359f-5 -4.9352682f-5 -1.6518557f-5 4.389409f-6 8.3009014f-5 -4.004834f-5 -7.821209f-5 3.9186427f-5 -0.00015528186 5.864928f-5 -0.00026284842 -0.00018987752 -9.9294564f-5 2.2062719f-5; -0.00017083064 9.657539f-5 0.00010298113 -0.00012581605 -3.0916995f-5 0.00012723784 3.7495196f-5 9.8387456f-5 5.8089387f-5 2.110895f-5 -4.286107f-5 -0.00012940075 5.143108f-5 3.97873f-5 -3.263793f-5 1.1533779f-5 -3.898364f-5 0.00013366286 1.6379294f-5 7.4834636f-5 4.9993138f-5 6.764863f-5 -4.1723142f-5 -6.36821f-5 3.058261f-5 0.00010481465 -1.1403645f-5 -5.2991845f-6 -0.00013618915 -3.972246f-5 -5.3914988f-5 8.268079f-5; -2.52669f-5 -4.2365646f-5 6.1155733f-6 0.00011792129 -8.695683f-5 5.1589795f-5 0.00016317256 6.399808f-5 0.00024367163 4.961889f-6 9.513909f-5 0.0001834422 -0.00015878266 -0.00024397936 0.00011017149 -5.6639376f-5 -0.000169033 -3.0776635f-5 9.4112904f-5 -3.07092f-5 -5.3147734f-5 5.2557625f-5 -6.6801105f-5 3.5853031f-6 -1.1650456f-5 0.00010431852 8.490821f-5 -4.4545308f-5 8.8124776f-5 0.00018932694 -8.326807f-5 -0.00017087423; 4.02413f-7 7.5899356f-5 -8.1459984f-5 0.000167482 
-8.202622f-6 0.00016006293 4.5386238f-5 7.4157324f-5 -3.140602f-5 1.538384f-5 -0.00019276785 -5.9088063f-5 -4.337253f-5 -0.00014809637 -7.196804f-5 -0.00011189876 -7.408388f-5 8.206836f-5 0.00012624227 0.00022497849 3.0079833f-5 -6.219347f-5 7.723122f-6 5.5082037f-5 -0.00010730325 9.2920214f-5 -6.574005f-5 0.00019488347 1.4727085f-5 -3.4186323f-5 -0.00013003129 -1.9871213f-5; 1.13583965f-5 1.2385568f-5 -3.3078777f-5 -3.7544665f-5 6.307921f-5 0.00013665955 -5.330077f-5 0.0001258086 8.915043f-5 -4.633264f-6 2.5150948f-5 -0.00019619179 -1.3461448f-5 -0.00011995876 1.3276934f-5 -3.855669f-5 -7.4441756f-5 -1.4018065f-5 1.5118289f-6 8.2688464f-5 -0.00011862563 -7.38602f-5 -3.535617f-5 -0.00021681665 -6.90515f-5 -6.513429f-5 -9.055478f-5 -0.00013057457 0.0001069964 -5.5465134f-5 -1.8626448f-5 -0.00021904746; -0.000102515376 -3.329889f-5 3.857652f-5 -3.0060664f-7 3.2677108f-5 -0.00012664679 2.2739188f-5 -8.596474f-5 6.675199f-5 -6.1026258f-5 -0.00011214431 7.6623975f-5 -4.2041924f-5 0.00019428683 7.612907f-5 -9.7997894f-5 5.2880903f-5 -8.073828f-5 -6.368918f-5 2.2410346f-5 -8.753189f-6 -0.00010903239 0.00015142852 -0.00010289612 -2.145241f-6 -0.000134039 1.3414351f-5 -4.2197436f-5 -1.1421615f-5 -5.6299963f-5 -4.7910296f-5 -0.00019151655; 3.19487f-5 3.4406785f-5 3.2212047f-5 -3.386411f-5 0.00018021243 -1.7619795f-5 -0.000115967436 0.00012715462 6.029539f-5 0.00017630201 8.15452f-5 6.6519155f-5 -0.00018838108 2.8445835f-5 1.0520676f-5 -6.8198286f-5 0.000131858 0.0001635929 5.694575f-5 -7.1124094f-5 -2.6367965f-5 0.00014813032 0.00014589124 6.689709f-6 -0.00012756273 -8.3340215f-5 5.2655992f-5 -0.00023034637 -0.00012691006 -0.00014323353 0.0001610738 4.059911f-5; -1.3394767f-5 5.6268094f-5 -7.4520053f-6 -0.00011715962 1.0722031f-5 0.00022095993 0.00014368519 -3.3034343f-5 2.3645442f-5 2.118239f-5 5.2021594f-5 -0.00012501198 -6.63428f-5 -1.16567435f-5 -8.617916f-5 -8.992589f-5 0.00017250078 7.356051f-5 0.00016292081 -0.00010304836 -0.00010684991 -0.00011247021 4.2167747f-5 
-3.346951f-5 0.000116116345 -3.2816977f-5 5.6969177f-5 -8.2696526f-5 5.9806865f-5 0.00018310297 5.8885995f-5 0.00014985439; -0.00014548501 -5.3577078f-5 -0.000102644895 -2.5584603f-5 -0.00018285982 -7.112601f-5 0.00018319592 -2.1849351f-5 5.373685f-5 2.838585f-6 4.977517f-5 0.00014602992 0.00022474483 -0.00016720995 9.436177f-5 0.00010611057 -2.2444023f-5 -0.00012473295 -9.612803f-6 0.00014296109 -1.1588689f-6 0.00022549236 4.973239f-5 5.8950103f-5 0.00014096564 2.3552037f-5 -6.8640647f-6 -3.3555243f-5 2.958516f-5 -9.942323f-5 -6.70455f-5 3.3955374f-5; -8.8101966f-5 3.7833135f-5 0.0001258931 1.5406227f-5 7.752525f-5 0.00011444477 2.3997474f-5 5.6456985f-5 -1.3731656f-5 0.000110488014 -9.060593f-7 -0.00011434729 -8.341201f-5 0.00014391419 4.5611796f-5 3.625062f-5 7.229019f-5 -1.5212598f-5 6.016332f-5 -0.0001406805 0.00018423435 -0.00010697063 5.7640016f-5 -5.847444f-5 -0.00030030336 5.421133f-5 -7.9287405f-5 -1.0043751f-5 0.00011785965 0.0001312531 2.5591627f-5 -0.0001594393; 3.704722f-5 9.191891f-6 -0.00012973219 -9.187165f-5 -3.7336446f-5 3.917717f-5 5.3289965f-5 -0.00012263034 -5.9285117f-5 9.4616575f-5 3.810336f-5 -3.768073f-5 3.8800114f-5 0.0003155684 -0.00017950605 -5.8545204f-5 1.15618f-5 -7.1766735f-5 7.6770964f-5 -2.3736178f-5 0.000116781026 2.2615784f-5 7.4701755f-5 -0.00010225712 -0.00012219072 -5.399725f-5 -0.00018572407 -8.923338f-6 -0.00010648547 3.528165f-5 5.2504847f-6 5.5094886f-5; -0.00010723806 -5.22317f-5 3.907536f-5 -6.22843f-5 -6.453198f-5 0.00017074752 -6.0573675f-5 -9.488635f-5 0.00024170538 -6.384209f-6 -2.8200218f-5 -7.605601f-5 -2.7428132f-5 7.3143696f-5 0.000118863114 -4.4588083f-5 -1.4372318f-5 -6.1905914f-5 -0.000170621 -1.02216045f-5 -9.99867f-5 -1.103301f-6 8.417026f-6 0.00021875335 -0.00015549749 -3.3803424f-6 -6.0580893f-5 -0.00010181135 0.00013676415 -0.00016925731 -0.0001059761 1.862227f-5; -1.9793531f-5 -0.0002458613 6.762877f-5 0.00014997504 5.329153f-6 4.2531887f-5 0.00019949264 0.00012516824 0.00019401695 0.0001010834 
5.2367253f-5 -1.9013878f-5 -4.310596f-6 5.7427093f-5 8.02568f-6 -2.1184182f-5 -0.00011093204 -0.0001201445 0.00013454811 -0.00011343138 -0.00010348 -6.8311165f-5 -6.590637f-5 3.332858f-5 -7.734974f-5 5.386471f-5 -4.4700195f-5 -0.00010770998 0.00014207611 9.052086f-5 -0.00012680642 -0.000109406516; -7.2990304f-5 -0.0001409586 -2.4206593f-5 0.00019625758 0.00016684129 1.977543f-5 0.00018552695 2.5862842f-5 -2.1988466f-5 1.7960763f-5 0.00012307346 -0.00016780203 0.00011685923 9.877229f-5 0.00012865127 -1.0434987f-5 -8.189504f-6 6.272357f-5 -8.4977563f-7 -0.00015581926 -2.8591483f-5 -2.0460075f-5 -0.00021083599 8.3116895f-5 -8.0299265f-5 0.00014697666 8.245075f-5 0.00022333564 -8.689001f-5 -4.6421726f-5 -0.0001421728 0.00022389524; 0.000120792225 -7.5462405f-5 -2.2761587f-5 8.2821396f-5 7.858781f-5 -0.00013338462 1.1886533f-5 -3.0358437f-5 9.597422f-5 -2.0737629f-5 1.904303f-5 -4.0594387f-5 -0.0001694644 -9.674283f-5 0.00021030848 6.0629845f-5 9.277687f-6 -0.00011269016 -9.705026f-5 4.0958854f-5 3.9712755f-5 8.559798f-5 0.00010305551 3.082235f-5 0.00012781554 -0.00016879484 0.00026652942 2.036087f-5 7.3917785f-5 9.526052f-5 3.7483187f-5 3.0688094f-5; -6.1845254f-5 -4.2489286f-5 0.00013306511 -2.9530222f-5 -1.4533242f-5 0.00015445294 -0.00012826746 3.9904407f-6 -3.194172f-5 -0.00012779448 -4.4395983f-6 -2.459271f-5 -0.000119960576 -0.00010586962 -9.8704964f-5 -0.00010831312 -0.00017142124 2.6619424f-5 9.952787f-5 0.0001034253 3.9260663f-5 -2.4238887f-5 -0.00013948296 1.821858f-5 -0.00010922454 -0.00015935225 2.1687552f-5 4.2896376f-5 9.9938174f-5 0.00015912039 8.125745f-5 0.00010357701; -7.640373f-5 1.5493759f-5 -0.00010201887 -4.4970075f-5 6.2906154f-5 5.014688f-5 -0.0001274386 3.1907635f-5 -8.947007f-5 4.176864f-5 5.835241f-5 -0.00013950825 0.00016382929 0.00025345126 -8.219709f-5 5.621984f-5 -1.9530667f-5 0.00016173259 0.00011156073 -8.3565494f-5 0.00012115905 0.00022112721 3.0175546f-5 2.2563f-5 8.2889805f-5 7.052733f-5 -1.2006735f-5 1.7645443f-5 -5.861451f-5 
-0.00016247002 -0.00012410169 -0.00015581703; -8.591911f-5 1.607042f-5 0.00020168304 -9.39981f-5 4.5950652f-5 7.07546f-5 0.00027750104 -6.4676206f-5 2.5331048f-5 -1.7940674f-5 0.00010501167 -3.604744f-5 -8.72509f-5 8.260499f-5 -4.7712597f-6 -5.0223614f-5 -1.9702153f-5 -6.2802625f-5 -3.1014179f-6 -1.8627208f-5 -0.0003857484 -3.5760648f-5 -6.1169653f-6 0.00018726582 0.00017722118 -0.0001609782 -6.2837876f-7 8.706469f-6 1.20931945f-5 0.00013537536 0.00021942466 -0.00013438455; -4.3264445f-5 5.274219f-5 2.47064f-5 7.299642f-5 -7.083198f-5 1.7987924f-5 1.3055287f-5 1.46265165f-5 5.5161003f-5 -1.4492474f-5 5.4223833f-6 8.091846f-5 -0.00016131553 -0.00015378982 4.511301f-5 1.10076135f-5 1.8133f-5 0.00017641639 5.6631885f-5 -7.816738f-6 7.9563004f-5 -0.00011734862 6.456508f-5 5.0090333f-5 -0.00013036051 2.883334f-5 -2.1860551f-5 -0.00016749086 3.5188634f-5 0.00011088423 6.6028165f-6 -8.836865f-6; 3.3009146f-5 0.00011714363 -7.3376286f-5 -0.000114717106 8.123068f-5 -1.8369943f-5 0.000237848 -7.3159736f-6 -7.122988f-5 -6.755504f-5 3.0826533f-5 7.03882f-5 -4.33621f-5 -4.9402388f-5 -6.964949f-5 3.503605f-5 -4.883035f-5 -6.656603f-6 4.656099f-5 4.922405f-5 0.00012309104 -1.4890419f-6 3.680918f-5 0.00012425086 9.034085f-5 9.691348f-5 0.00011311517 0.000121131256 -0.00011680274 3.1623957f-5 -0.00024556986 1.2408618f-5; -0.0001307977 3.402081f-5 -0.00017316277 -0.0001090153 -0.00019153289 -7.679229f-6 0.00011021593 -0.00015973856 6.0890892f-5 -0.00012490578 7.5933854f-6 -4.3746036f-6 0.0001208226 6.697809f-5 -6.39576f-5 -2.4293374f-6 -0.00012214268 -1.4026753f-5 -4.8168804f-6 -5.818674f-5 -4.6004025f-5 -0.00019796574 -5.2663632f-5 4.8630816f-5 -8.087136f-6 5.657263f-5 7.3305935f-5 -7.844477f-6 -3.5451103f-5 7.2114555f-5 -7.167388f-6 -3.9850216f-5; 0.00032036274 0.00014830838 -0.00011518331 -0.0001737737 6.940856f-5 0.00011050573 -0.00019535642 -2.0748248f-6 -8.466353f-5 9.493442f-7 0.00020249424 -6.134204f-5 3.1714248f-5 0.00012622819 -1.1363703f-5 -1.4896446f-5 0.00014260215 
9.358694f-5 -0.00016153688 -0.00021008613 0.00017355198 4.2717886f-5 -1.962473f-5 9.78363f-6 5.8102065f-5 9.033535f-5 5.4220105f-5 -5.6332858f-5 5.8226015f-5 1.626015f-5 -2.7879285f-8 4.8927915f-5; 2.39549f-5 4.5243272f-5 2.176681f-5 6.122727f-5 -0.00010887865 -1.8297698f-5 -5.5414694f-5 0.00015014755 3.8435064f-6 -0.00010917717 -0.00010001055 -3.6780388f-5 -4.3899414f-5 -1.2734123f-5 1.1320081f-5 1.5735059f-5 8.255435f-5 0.00013562046 0.000105540756 -2.1283371f-5 2.9844745f-5 -2.82065f-5 -3.354192f-5 -2.1515758f-5 -0.00012508243 -2.1877875f-5 2.5843749f-5 -0.00018540063 3.1956173f-5 -1.6471624f-5 -1.0087316f-5 1.0099328f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-7.580879f-5 -1.26180175f-5 -0.00010368026 5.2604373f-5 -5.0258226f-5 3.0480553f-5 -3.88889f-5 0.0001659476 0.0002058711 7.2801726f-5 -7.93586f-5 -6.1494655f-5 0.000118986645 -3.0455203f-5 -0.00012197362 -1.0076472f-5 -0.00014221018 -5.230134f-5 -5.1141593f-5 -0.00013649564 7.216723f-5 -0.0003202819 -0.00010056202 0.00014149101 -0.00020836512 -8.649498f-5 0.00014506685 -7.663297f-5 0.00011285033 -3.1712316f-5 4.4604112f-5 -1.4948189f-5; 2.2432901f-5 3.6835623f-5 0.000174449 6.977765f-5 -3.6323418f-5 -5.2063442f-5 -6.320214f-5 -3.7064383f-5 0.00017938192 -4.6872403f-5 8.702507f-5 5.5398472f-5 -0.00017713325 4.536835f-5 -5.57951f-5 5.016454f-5 6.8129084f-5 -7.5831063f-6 -4.8039587f-5 -0.00013074097 -0.00020246024 7.361607f-5 6.0416925f-5 7.7730976f-5 4.934469f-6 1.5887934f-5 0.00013034108 -0.00011123116 -4.7659745f-5 0.000113766 -8.892609f-5 7.26999f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64
# Convert the Float32 parameters to Float64, then wrap them in a ComponentArray so the
# optimizer works on a single flat parameter vector.
const params = ps |> f64 |> ComponentArray
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
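# plus 0 states.
Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected by the neural network, using the state $u_1 = \chi$ and $u_2 = \phi$,
where $p$, $M$, and $e$ are constants.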
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE. The state is `u = [χ, ϕ]`, and `t` is unused.

Both derivatives share the Newtonian rate `(1 + e cos χ)^2 / (M p^(3/2))`, which is
multiplied component-wise by `1 .+ nn_model([χ], nn_params)`. The constants `p`, `M`
and `e` are read from the global `ode_model_params`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # In this example `st` is an empty NamedTuple, so it can safely be ignored. In
    # general, the neural network's state should be tracked in `st`.
    correction = 1 .+ nn_model([first(u)], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Simulate the neural ODE with the untrained parameters and compute its waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn =
    solve(prob_nn, RK4(); u0, p=params, dt, saveat=tsteps, adaptive=false) |> Array
waveform_nn = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)[1]
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")
    # Reference waveform: line plus sampled points.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; strokewidth=2, alpha=0.5, markersize=12, marker=:circle
    )
    # Prediction of the untrained neural ODE, drawn the same way.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; strokewidth=2, alpha=0.5, markersize=12, marker=:circle
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the `h₊` waveform produced by the neural ODE `prob_nn` with
parameters `θ` and the reference `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    pred_waveform, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(pred_waveform, waveform)
end
# Warm up the loss function
loss(params)
0.000695159446876938
Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(state, l)

Optimization callback: append the current loss `l` to `losses` and print the iteration
number `state.iter` together with the loss. It always returns `false`;
Optimization.jl interprets `true` as a request to stop.
"""
function callback(state, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, l)
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. It appears to work well here because the Newtonian model already provides a very good initial guess.
# Minimize `loss` over the network parameters with BFGS and a backtracking line search;
# gradients come from Zygote. The hyperparameter argument of the objective is ignored.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback,
    maxiters=1000,
)
# retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00029180289129711535; 0.00016396421415244567; -0.00015462728333658055; -0.0001412374986101823; 2.155355832652848e-5; -8.221679308918412e-5; 1.6669822798548044e-5; -0.00021235232998131613; -0.00021415077208074916; 4.339118822822159e-5; 0.00010623952402947274; -0.00013206031871969862; -5.365259130486651e-5; 8.104361040748873e-5; 9.736525680619991e-5; -6.86032508383338e-5; -3.5452785596102425e-7; -9.231755393545338e-5; 2.4067065169212632e-5; 0.00011212930985497247; 8.15858293207503e-5; -9.260551450990443e-6; -7.644977449670603e-5; 2.5526609533659636e-5; -8.425935811819239e-5; -3.136982195425716e-5; -1.9138069546864798e-6; 0.00010249789193030607; -9.143747593046195e-5; 6.349011982812371e-6; -0.0001733742683426586; 0.0002509967598593924;;], bias = [-1.6898122266519217e-16, 2.362803037147358e-16, -1.3909580648041273e-17, -1.0901702669686893e-16, 3.3468148697536897e-17, 1.1125462965288716e-17, 3.122909213643224e-17, -2.1927102534360175e-16, -5.613485843137787e-18, 5.46481558826632e-17, 2.1020533822166556e-16, 1.1372921131259724e-16, 1.3616484829083901e-17, 1.7501955337177332e-17, 1.895507424501439e-16, -2.0752077627218575e-17, -3.7336513402574545e-19, -1.3378195470007047e-16, 7.905031921755135e-18, -4.4174364076969744e-17, 8.745393534982515e-17, -1.8186747366504138e-17, -1.828250506685702e-17, 4.38866443476286e-17, -1.0858627807329214e-16, -3.2900855053120614e-17, -3.6957107794042865e-18, 1.0085793078112123e-16, 2.9938022126840804e-17, 8.74180753920952e-18, -9.289972387276974e-17, 5.064087082766631e-16]), layer_3 = (weight = [-3.459702920484841e-5 -9.511721052378091e-6 0.00018080215168333577 -0.0001704282293733315 0.00010779990780314786 5.7241115521486253e-5 -0.0001534484871601837 0.00011164212274047998 -0.000180693499621381 0.00013770055748204153 3.395917529155172e-5 -5.8426572556758284e-5 0.00019652139244748687 -7.74299295674215e-5 -0.0001911127873402219 -6.179479159732831e-5 5.990239608528733e-6 
-9.063828563063125e-6 5.283411528829843e-5 -2.261709223353602e-5 0.00016521646839714175 0.00010787702567790049 -5.210227018074585e-5 -0.00014179272320439683 -1.1983923193162357e-5 0.0001371277358909919 -5.6960205163030245e-5 -4.530322976029824e-5 -4.722793518999987e-5 2.3851509561374503e-5 -0.00011203113890337652 -7.826002901969542e-5; -9.075116491748394e-6 -0.00011014860875189269 -0.0001395365580775006 6.401723284757273e-5 -0.00012878235802967126 4.258756225131932e-5 -1.8355151138198642e-5 -0.00024089291774186586 3.691464725664486e-5 0.00011573422297621029 2.440282686980611e-5 8.432027597696268e-5 5.15204336752165e-5 -0.00014253984054998003 0.00014375586387719322 9.819078417387005e-6 0.000132860437492095 0.0001245035490628275 0.00012131639586547361 -0.00012035829068376601 0.00011180733044417132 0.00023625627717953178 4.647004961403628e-5 -1.4676901228409158e-6 3.718868528627315e-5 1.8930144409718362e-5 2.747709590007539e-6 3.135884607577921e-5 -6.942151604202206e-5 9.974816557038582e-5 -3.9518072395419666e-5 6.023787127007158e-5; 0.00011096256631669998 -0.00018745493442746868 -2.3133961883097797e-5 2.3535923564874645e-6 0.00018331249876881014 -2.3250155288235307e-5 -0.0001705527248186807 -7.91309358528166e-5 -1.0888844451044614e-5 -0.0001730728837055885 0.0001154214477655776 0.00013101281357467073 -0.00019418382734058174 -7.962708047682638e-6 -1.3119023685257407e-5 8.82477033091939e-5 3.3394827941991997e-6 -4.512291715128445e-5 0.0001506795232789778 -0.00020149687778667224 8.339214392707392e-5 -4.571260077411353e-5 7.264778824991172e-5 3.2682601079596364e-5 -7.440696216418653e-5 -3.0195153237417772e-5 -0.00013916556789537944 -0.00011492289203216215 -6.029940421923179e-5 5.7967466921563705e-5 -0.0002561337606375543 -7.250144707263475e-5; -0.00015928562160132424 1.1529746108675222e-5 -0.00012295619770903398 1.4253027904297426e-5 8.271545713049208e-5 8.585872902748246e-5 -3.989002751788284e-5 -7.453338283830769e-5 0.00017898710840363568 0.00010892791661879876 
-0.00022269079007693343 -8.599501696900019e-5 5.8037562586636524e-5 -0.00019867136969716227 6.718176785277718e-7 0.00013247368670980445 0.0001196841393828282 -9.812621382317377e-5 2.556351220099107e-5 6.847223167986884e-5 2.1387610666221654e-6 -1.8659796376793448e-5 -0.0001395353111440405 5.161474666041022e-6 -0.0002946755349405195 1.9755378991052093e-5 0.0002110833947079146 -9.583379968554515e-5 -3.661408580975482e-5 -3.237820518572051e-5 0.00018019007157981352 2.0334938118874072e-5; 6.139028098252543e-5 0.0001931870135857029 0.00010950437697174889 3.761049426685501e-5 -8.34374098045763e-6 2.365442539742371e-5 -3.371569275910954e-5 4.788842282119793e-6 6.119236038355214e-5 5.861527076634395e-5 8.66097032238224e-5 -0.00013331491957359226 1.5300310927007095e-5 9.123392353360719e-5 0.00029954895560224367 0.0001355804049788549 7.755856474690182e-5 1.740089167310492e-6 -8.685000668916876e-5 -7.509648296689236e-5 -8.30018836742716e-5 -0.00012513435670113152 5.942704208139596e-5 -1.817796349056565e-5 1.965430580896904e-5 6.609233221100568e-5 8.332240374587519e-5 -9.35122372786079e-5 6.625400849611798e-6 -3.985876194489149e-5 0.00011875199540808914 0.00011348113164126711; 4.58802919094189e-6 6.870932742462838e-6 -4.225513344859665e-6 -1.4950716427844462e-5 -0.0001450166669710602 0.00011296097404192617 -0.00011599587701857723 7.455240888158038e-5 4.00742359888759e-5 0.00012237164668840687 -5.632057541879327e-5 -2.739803999793107e-5 -2.2885854883497853e-5 -1.8643169420101284e-5 3.3412703891373226e-5 -2.292843560644513e-5 -0.0001911359791447472 -0.0001376861687785994 0.00012264158471589283 2.8245820834856094e-5 -0.0001269228446799382 2.4045108551944754e-5 -3.9598016479430473e-5 0.00018763803422552843 -0.0001300780345954983 0.00023284342791521305 -4.6318258278916455e-5 -4.819811106790306e-5 -7.421536844048163e-6 -4.468602815541145e-5 3.26072262796564e-5 2.6384887354663e-5; 0.00019794417515432263 -0.00010020197602654737 0.00015737448323550578 -3.094239015565483e-5 
-0.00010872109648274399 7.860233948364987e-5 -0.00011579770013425648 -0.00011793169484673139 -6.611746145892333e-6 1.9154700778624186e-5 -5.90231263843891e-5 -3.5749193345805274e-5 6.509392754439228e-5 -7.847032044749582e-5 -1.971644066300459e-5 1.9109946363340563e-5 -1.197431790145275e-5 7.213944102902173e-5 -3.918204107988762e-5 -7.749723387617236e-5 -2.6727989977489813e-5 -2.5327397240592806e-5 3.5225059333871957e-6 4.809887671641872e-5 -0.0001445199575458266 -5.3868397815070375e-5 2.3549377358600755e-5 4.6304720156214355e-5 7.870541069921307e-5 -0.00016158359155027356 -6.511728658308066e-5 0.00010404073166928028; -1.4946690023834899e-5 4.11064552482374e-5 -4.964459592693501e-5 1.2282474492256411e-5 1.6759500360529624e-5 -9.52207287608764e-5 -9.341490156464844e-5 -9.736646686899639e-5 0.00012049510907861168 5.776430182060439e-5 4.501401192292823e-5 -0.00011462421614660283 -9.314056158280812e-5 0.00015411800770572158 -9.21046874090373e-6 -3.994967344432193e-5 0.00012417712198122796 0.00011803309236903259 1.3830733698449204e-5 1.5583390924899273e-5 7.770045114931412e-5 -1.4794578853872029e-5 4.209611645102338e-5 0.00033426824440732004 -0.00012016913613287083 8.293336352118588e-5 -5.740712229260784e-5 4.041873537253914e-5 0.00011313269853609214 7.46556156828945e-5 -1.8109304144593363e-5 0.00011219483487552963; 1.3741947706344e-5 6.691617117348475e-5 2.0628040412247875e-5 9.217520372795229e-5 0.0003665634688254906 1.1250678009240359e-5 0.00015410848963395667 0.00017168643197826023 5.332113895880808e-5 -9.353668038645058e-5 0.00011692484648985305 -0.00011818090907444572 -2.2801453913844954e-5 -6.240427599983578e-6 8.610176545655099e-5 -0.0001324029469364847 -0.0001792861182443331 4.1803687319335106e-5 0.00015554362861760828 0.00012416908820479463 0.00016793980857212197 0.00016096881544538772 -1.2758485089284678e-5 5.9616297486656524e-5 -0.00014541548781242178 -4.319825867102599e-5 4.0441889988444806e-5 8.323363389674188e-5 -0.00022362340287916479 8.185916184769013e-5 
-9.049912087767437e-5 3.268643585347093e-5; 9.563969939782404e-5 -2.9973304149405835e-5 -2.2200245835065975e-5 0.00010603678089397001 3.3615547695117357e-5 2.8185750714741075e-6 -4.239121920832976e-5 -9.597367028174958e-5 6.458543936321533e-5 5.813024607427008e-7 -3.195867469471613e-5 6.239106883052958e-6 -2.3947248732982e-5 -2.0751129736849913e-5 2.189164257729095e-5 -8.804416432681711e-5 3.667093693806227e-5 2.081766030173347e-5 3.850130593931712e-5 -4.935496548127951e-5 -1.6520840686239113e-5 4.387125559187503e-6 8.300673067864831e-5 -4.0050623679380436e-5 -7.821437469388399e-5 3.9184143647472114e-5 -0.00015528414081487435 5.864699557136954e-5 -0.00026285070492590727 -0.00018987980587113277 -9.929684750671713e-5 2.2060435699032414e-5; -0.0001708296230661945 9.657640755008212e-5 0.00010298214608169496 -0.0001258150340889627 -3.091598130580789e-5 0.0001272388572671448 3.7496210161551825e-5 9.838847024197326e-5 5.809040073485654e-5 2.1109964173009086e-5 -4.286005508206031e-5 -0.00012939973833440255 5.143209552457656e-5 3.978831507097908e-5 -3.2636916266569186e-5 1.1534793038806343e-5 -3.8982624509882095e-5 0.0001336638770864336 1.6380307961443336e-5 7.48356497480512e-5 4.9994151659263045e-5 6.764964479785605e-5 -4.172212805300949e-5 -6.368108454793518e-5 3.0583624784554884e-5 0.00010481566557260393 -1.1402630598096088e-5 -5.298170371902301e-6 -0.00013618813402885654 -3.9721447073899136e-5 -5.391397365188094e-5 8.268180233256865e-5; -2.5265516643992635e-5 -4.236426260086694e-5 6.116956381622319e-6 0.00011792267529762675 -8.695544905258071e-5 5.1591178515264644e-5 0.00016317394390007764 6.399946202525995e-5 0.00024367301436375917 4.963271902127642e-6 9.514047330248885e-5 0.0001834435831080514 -0.00015878127244995743 -0.0002439779776357673 0.00011017287094382433 -5.6637992503781926e-5 -0.0001690316124329873 -3.077525191820317e-5 9.411428679250446e-5 -3.070781834303674e-5 -5.314635139420636e-5 5.25590082072123e-5 -6.679972223700925e-5 3.586686187376342e-6 
-1.1649073105035619e-5 0.00010431990146408106 8.49095965060807e-5 -4.454392486131931e-5 8.812615912411277e-5 0.0001893283194187499 -8.326668406135532e-5 -0.00017087285171494572; 4.0289004149200864e-7 7.589983261286345e-5 -8.145950732423814e-5 0.00016748247117365878 -8.202144723236533e-6 0.00016006341002512163 4.538671464988502e-5 7.415780101300002e-5 -3.1405542809723346e-5 1.5384316160725242e-5 -0.00019276737611644598 -5.908758566502273e-5 -4.337205200126894e-5 -0.0001480958921964895 -7.196756223120396e-5 -0.00011189827955231598 -7.408340342944945e-5 8.206884058754304e-5 0.00012624275087009648 0.00022497896543800181 3.0080309901767422e-5 -6.219298984060522e-5 7.723598631945572e-6 5.508251388537232e-5 -0.00010730277017128285 9.292069126249752e-5 -6.573957534410516e-5 0.00019488394582740343 1.4727561851121698e-5 -3.4185846285163325e-5 -0.00013003080993794823 -1.987073619230674e-5; 1.135471854515414e-5 1.2381890037835725e-5 -3.308245498334119e-5 -3.7548343179482736e-5 6.307553491788846e-5 0.00013665587545839822 -5.3304447582914474e-5 0.00012580491972627726 8.91467535170424e-5 -4.6369418028813575e-6 2.5147270271576144e-5 -0.00019619546576865848 -1.3465126293874485e-5 -0.00011996243477331745 1.327325578169741e-5 -3.856036703799414e-5 -7.44454336468981e-5 -1.4021742508320978e-5 1.5081510086107541e-6 8.268478638494834e-5 -0.00011862930471541621 -7.386388090671154e-5 -3.535984790649623e-5 -0.00021682033177929975 -6.90551786916165e-5 -6.513797045917259e-5 -9.05584565130913e-5 -0.00013057824610730436 0.00010699272031199131 -5.5468812142490674e-5 -1.8630125786832315e-5 -0.00021905114038373388; -0.00010251813018760627 -3.330164345169056e-5 3.8573765596823085e-5 -3.033607746975398e-7 3.2674353679603636e-5 -0.00012664954210704678 2.273643422157834e-5 -8.59674948507439e-5 6.674923843168873e-5 -6.102901188497422e-5 -0.00011214706701066427 7.66212213427993e-5 -4.204467786549493e-5 0.00019428407158629467 7.612631798183304e-5 -9.800064817609937e-5 5.2878149109155196e-5 
-8.074103625331401e-5 -6.369193233623276e-5 2.2407591860248122e-5 -8.75594305326028e-6 -0.00010903514721504566 0.000151425770855708 -0.00010289887377367455 -2.147995200360794e-6 -0.00013404175497199478 1.3411596956565774e-5 -4.220019054555485e-5 -1.1424369585440664e-5 -5.630271716582746e-5 -4.7913050547486925e-5 -0.0001915193026426538; 3.195054237261027e-5 3.4408626427169184e-5 3.221388838253465e-5 -3.3862269581297696e-5 0.0001802142731798908 -1.7617953878938067e-5 -0.00011596559411146646 0.0001271564623061396 6.0297233042056724e-5 0.00017630385341589277 8.154704196611722e-5 6.652099623577605e-5 -0.00018837923484853593 2.844767627947152e-5 1.0522517961803282e-5 -6.819644399685089e-5 0.00013185984503478795 0.00016359473452438137 5.694759315928457e-5 -7.112225207267103e-5 -2.6366123299807096e-5 0.0001481321604410284 0.0001458930863486666 6.6915507454635455e-6 -0.00012756088758352592 -8.333837390410724e-5 5.265783406911465e-5 -0.00023034452819730735 -0.00012690821963361917 -0.00014323168502188238 0.00016107563792727069 4.060095020799902e-5; -1.3392691859829428e-5 5.627016941883993e-5 -7.449930241871502e-6 -0.00011715754810852996 1.0724106298730197e-5 0.0002209620084367589 0.00014368726403787133 -3.303226769919419e-5 2.364751669385973e-5 2.1184464979304388e-5 5.202366953467537e-5 -0.00012500990708490733 -6.634072490400478e-5 -1.1654668403066223e-5 -8.617708811202377e-5 -8.992381338161382e-5 0.00017250285015776173 7.356258863451663e-5 0.0001629228878054706 -0.00010304628670114593 -0.00010684783631135807 -0.00011246813521851439 4.2169821947896426e-5 -3.346743544811074e-5 0.00011611842053112445 -3.281490210341935e-5 5.6971251590730414e-5 -8.269445099999557e-5 5.980894054387482e-5 0.00018310504089909954 5.88880698823574e-5 0.00014985646118740353; -0.00014548344623095346 -5.3575510519135803e-5 -0.0001026433281184141 -2.55830352942406e-5 -0.00018285825060939754 -7.112443998249909e-5 0.00018319749071469204 -2.1847784209817668e-5 5.3738417791613906e-5 2.840152292436656e-6 
4.977673890624237e-5 0.00014603149176780798 0.00022474639556111309 -0.0001672083788251383 9.436333786669323e-5 0.00010611213949249505 -2.2442455501559546e-5 -0.00012473138192928217 -9.611236036844733e-6 0.0001429626529172154 -1.157301680735968e-6 0.0002254939274463957 4.9733956275473744e-5 5.8951670526029786e-5 0.00014096720698961315 2.3553603769132014e-5 -6.862497484517173e-6 -3.35536758755698e-5 2.9586727778391965e-5 -9.94216650540786e-5 -6.704393015736696e-5 3.395694076985325e-5; -8.810079809371553e-5 3.783430300212637e-5 0.00012589426849412088 1.54073951122348e-5 7.752641620722842e-5 0.00011444593481396329 2.3998642136256836e-5 5.6458153339581986e-5 -1.3730487735637799e-5 0.00011048918175188082 -9.048914492769883e-7 -0.00011434612308253937 -8.341083950686117e-5 0.0001439153603129369 4.561296380383794e-5 3.62517894970075e-5 7.2291357427866e-5 -1.5211429781951642e-5 6.016448613049702e-5 -0.0001406793270874511 0.00018423551598939115 -0.00010696945893901144 5.764118402989619e-5 -5.847327331171135e-5 -0.0003003021956104506 5.421249814357831e-5 -7.928623737857418e-5 -1.0042583126373573e-5 0.0001178608182086883 0.00013125426274168564 2.5592795354566128e-5 -0.0001594381282487817; 3.7045891525704574e-5 9.19056355434164e-6 -0.00012973351431918534 -9.187297428633457e-5 -3.7337772943100893e-5 3.917584080573163e-5 5.328863744016449e-5 -0.0001226316703707002 -5.9286444871921095e-5 9.461524722955902e-5 3.8102033153159534e-5 -3.768205670548152e-5 3.879878613024774e-5 0.00031556708244698104 -0.0001795073799314153 -5.8546530914279674e-5 1.156047300701195e-5 -7.17680627587431e-5 7.676963695648781e-5 -2.373750567693953e-5 0.00011677969859538585 2.261445706044493e-5 7.470042737059247e-5 -0.00010225845063014229 -0.00012219204245975235 -5.39985771922141e-5 -0.00018572539875658532 -8.924665513581936e-6 -0.00010648680063804552 3.5280322202925386e-5 5.2491573058564055e-6 5.509355877969034e-5; -0.00010724014451449391 -5.223378126002675e-5 3.9073276768149547e-5 -6.228638254298189e-5 
-6.453406409321083e-5 0.0001707454340629535 -6.0575757614176014e-5 -9.488843521293725e-5 0.00024170329392464413 -6.386291586864823e-6 -2.8202300980518885e-5 -7.605808956042392e-5 -2.7430214557284838e-5 7.314161315832963e-5 0.00011886103142558537 -4.459016586963466e-5 -1.4374400832937046e-5 -6.190799636726391e-5 -0.00017062308493962232 -1.0223687111754823e-5 -9.998877963709894e-5 -1.1053835879713354e-6 8.414943424101688e-6 0.00021875127107875706 -0.00015549957048108086 -3.382424991401947e-6 -6.058297536412947e-5 -0.00010181342973566605 0.0001367620699437601 -0.00016925939585974199 -0.00010597818059805393 1.862018702187836e-5; -1.9793190877038542e-5 -0.00024586096107676584 6.762910787066524e-5 0.00014997537662558383 5.3294934149714405e-6 4.253222730759803e-5 0.00019949298301063823 0.00012516858121663175 0.00019401728648509665 0.00010108374270195469 5.2367593962040144e-5 -1.9013537457872528e-5 -4.3102554810353245e-6 5.7427433130138496e-5 8.026020150205668e-6 -2.1183841017849597e-5 -0.00011093169589010651 -0.00012014415689926342 0.00013454845473453742 -0.00011343103826228321 -0.00010347966010166059 -6.83108242083865e-5 -6.590603290540815e-5 3.33289194531214e-5 -7.734939710652188e-5 5.386505152265878e-5 -4.469985426120113e-5 -0.00010770963992407316 0.00014207645152047532 9.052119865365674e-5 -0.00012680608189847312 -0.0001094061749609782; -7.298785146220785e-5 -0.0001409561440314372 -2.4204139998829496e-5 0.00019626002961647127 0.00016684374043596966 1.977788327256701e-5 0.0001855294036713436 2.5865295020548878e-5 -2.198601338845994e-5 1.796321578784787e-5 0.00012307591589977152 -0.00016779958183999545 0.0001168616805725717 9.877474133735092e-5 0.00012865372321446689 -1.043253484478039e-5 -8.18705154325488e-6 6.272602388979278e-5 -8.473230595979622e-7 -0.0001558168072451929 -2.8589030415802005e-5 -2.045762209400915e-5 -0.00021083353873929385 8.31193471012641e-5 -8.029681282006678e-5 0.00014697911013262363 8.245320407785568e-5 0.00022333808795427156 -8.688755712593551e-5 
-4.64192738663862e-5 -0.00014217034128617947 0.00022389769640625168; 0.00012079475018715229 -7.545988009485575e-5 -2.2759062262479318e-5 8.282392118883726e-5 7.859033249146438e-5 -0.00013338209063918708 1.1889057900723149e-5 -3.0355911435338355e-5 9.597674252755163e-5 -2.0735103675955743e-5 1.9045554295314422e-5 -4.059186139317551e-5 -0.00016946187642575305 -9.674030261054068e-5 0.00021031100950962266 6.063237051889086e-5 9.280212366454611e-6 -0.00011268763736994529 -9.70477336476079e-5 4.09613789135039e-5 3.971528022260964e-5 8.560050664995961e-5 0.00010305803189580024 3.0824874925335186e-5 0.0001278180647251621 -0.00016879231370226392 0.0002665319410927145 2.0363395747279404e-5 7.392031002849761e-5 9.526304384521377e-5 3.748571216885513e-5 3.0690618955387404e-5; -6.184671883694658e-5 -4.249075073043646e-5 0.0001330636491720288 -2.9531686323980553e-5 -1.4534707134826343e-5 0.00015445147209806507 -0.00012826892862091013 3.98897600367777e-6 -3.194318424087801e-5 -0.00012779594772024177 -4.4410630844637625e-6 -2.4594174193272388e-5 -0.0001199620405723718 -0.00010587108636269217 -9.870642907213152e-5 -0.00010831458576017952 -0.0001714227059902083 2.6617959064373182e-5 9.952640311797923e-5 0.00010342383529764083 3.925919872830755e-5 -2.4240351650460573e-5 -0.00013948442438146 1.821711475617098e-5 -0.00010922600131476144 -0.0001593537112976811 2.168608761980013e-5 4.289491107460429e-5 9.993670891976198e-5 0.00015911892614798437 8.125598605961862e-5 0.00010357554628987835; -7.640266618325207e-5 1.5494824891401002e-5 -0.00010201780593120062 -4.4969008614815684e-5 6.290722025228978e-5 5.0147944546966825e-5 -0.00012743753528551813 3.1908701328976766e-5 -8.946900548977177e-5 4.1769704820087954e-5 5.8353474867957946e-5 -0.00013950718481423608 0.00016383035484837598 0.0002534523265589597 -8.219602374223993e-5 5.6220906243158296e-5 -1.952960143311552e-5 0.00016173365669461 0.00011156179620089319 -8.35644276426557e-5 0.00012116011579424241 0.00022112827367751252 
3.0176612325209523e-5 2.2564065273816967e-5 8.289087067028591e-5 7.052839295143556e-5 -1.2005669178999005e-5 1.764650927585048e-5 -5.861344458166471e-5 -0.0001624689574432899 -0.00012410062105259643 -0.00015581596732005472; -8.591800499910532e-5 1.6071528505735976e-5 0.00020168414854942178 -9.399698822726997e-5 4.595176002858404e-5 7.075570690012507e-5 0.00027750214437055604 -6.46750973208482e-5 2.5332156354291033e-5 -1.7939566106906563e-5 0.00010501277799284828 -3.604633232662865e-5 -8.724978900481835e-5 8.260609544613375e-5 -4.770151492087265e-6 -5.0222506189858084e-5 -1.970104452247983e-5 -6.280151640731002e-5 -3.1003096699346085e-6 -1.8626100001533087e-5 -0.0003857472889028958 -3.575953954332297e-5 -6.115857133228519e-6 0.00018726692584841704 0.00017722229173905177 -0.00016097709661316877 -6.272705666553957e-7 8.707577279161898e-6 1.2094302717921398e-5 0.00013537646961431235 0.00021942577083626828 -0.00013438344242952012; -4.326399549770957e-5 5.2742639933266335e-5 2.4706849113708453e-5 7.299687230100082e-5 -7.083152933206897e-5 1.7988373552936855e-5 1.3055736313829938e-5 1.46269657181858e-5 5.51614520065651e-5 -1.449202510572547e-5 5.422832497271545e-6 8.091890584749365e-5 -0.00016131508356353794 -0.00015378937142829054 4.511346004276849e-5 1.1008062652974951e-5 1.8133448871807048e-5 0.0001764168390638879 5.6632334132003366e-5 -7.816288519978644e-6 7.956345315163377e-5 -0.00011734817427510878 6.45655289974752e-5 5.0090782575565145e-5 -0.0001303600603148586 2.8833788643381265e-5 -2.186010180027239e-5 -0.00016749040706942492 3.518908296085292e-5 0.00011088468262299415 6.603265670699895e-6 -8.83641597627197e-6; 3.301108779319122e-5 0.00011714557228896247 -7.337434465244744e-5 -0.00011471516438705933 8.123262421050308e-5 -1.8368001406814038e-5 0.00023784994016223804 -7.314031905320308e-6 -7.122793715627709e-5 -6.755309834788444e-5 3.0828475131965134e-5 7.039014496894623e-5 -4.336015729295039e-5 -4.940044630147698e-5 -6.96475482307504e-5 3.5037991323591496e-5 
-4.8828406875866956e-5 -6.654661434185234e-6 4.656292992349907e-5 4.9225992255922424e-5 0.00012309298369855433 -1.4871002356752568e-6 3.681112036629445e-5 0.00012425280044608936 9.034279253155709e-5 9.691541970296931e-5 0.00011311711466460718 0.000121133197239214 -0.00011680079580928493 3.1625898258613836e-5 -0.0002455679142493082 1.2410559860803328e-5; -0.00013080100781082823 3.401750926705619e-5 -0.0001731660713748146 -0.00010901860242747907 -0.00019153619088040167 -7.682529951960237e-6 0.00011021262888523296 -0.00015974185681704727 6.088759163242552e-5 -0.00012490908285410428 7.59008479425363e-6 -4.377904145171716e-6 0.0001208193016164639 6.697479237762498e-5 -6.396090340141603e-5 -2.432637996134965e-6 -0.00012214597974245922 -1.4030053546699121e-5 -4.820180959472946e-6 -5.819003946536286e-5 -4.600732519331158e-5 -0.00019796903963008439 -5.26669328000196e-5 4.8627515740986485e-5 -8.090436506715626e-6 5.656932897825345e-5 7.330263450701234e-5 -7.84777786331074e-6 -3.545440359895637e-5 7.211125465535076e-5 -7.170688616397582e-6 -3.985351669403026e-5; 0.00032036531751389384 0.00014831095153815987 -0.00011518073757294373 -0.0001737711284088291 6.941113655807732e-5 0.00011050830630006023 -0.00019535384513551075 -2.072249351192097e-6 -8.466095483280487e-5 9.519196671612863e-7 0.00020249681232848514 -6.133946613523378e-5 3.171682305287649e-5 0.00012623076303748188 -1.136112792415083e-5 -1.4893870918285834e-5 0.000142604723571598 9.358951478384045e-5 -0.0001615343064779575 -0.000210083555664148 0.00017355455178653587 4.27204611176004e-5 -1.9622153792853282e-5 9.786205503363471e-6 5.81046403956386e-5 9.033792570585622e-5 5.422268053985517e-5 -5.6330282184040136e-5 5.822858997157398e-5 1.626272546014132e-5 -2.5303829651872564e-8 4.893049094125368e-5; 2.3954193839382365e-5 4.5242565338763973e-5 2.1766102382012524e-5 6.122656364858326e-5 -0.00010887935908578607 -1.8298405247713774e-5 -5.5415400620059076e-5 0.00015014684200686225 3.842799433542685e-6 -0.00010917787707478045 
-0.00010001125459859168 -3.6781094731501044e-5 -4.3900120853886296e-5 -1.2734829706113234e-5 1.1319374298529524e-5 1.5734351709744088e-5 8.255363960450233e-5 0.00013561975234586659 0.00010554004862604394 -2.1284078083787976e-5 2.9844037872300125e-5 -2.820720632336848e-5 -3.354262784307038e-5 -2.151646489403012e-5 -0.00012508313497129947 -2.187858202853034e-5 2.584304153095183e-5 -0.00018540133290991228 3.1955466202254435e-5 -1.6472330880442654e-5 -1.0088022772481768e-5 1.0098621405991742e-5], bias = [-7.093124243546903e-10, 2.1062932909600024e-9, -3.172396664097522e-9, -1.0851074120201966e-9, 3.746911567153001e-9, -5.598449212061756e-10, -1.3367456027559307e-9, 2.9517851715969057e-9, 3.896540449999236e-9, -2.283227477277193e-9, 1.0141504016658332e-9, 1.3830396748561708e-9, 4.770336999600249e-10, -3.6779325374796238e-9, -2.754131138064877e-9, 1.8415926122167269e-9, 2.075078518863574e-9, 1.5672596326291647e-9, 1.1678636379429187e-9, -1.3274131559049225e-9, -2.0826138345921325e-9, 3.405875002309233e-10, 2.4525715978717967e-9, 2.525125771602753e-9, -1.4647421818148951e-9, 1.0660537178898878e-9, 1.1081903728640226e-9, 4.491869125854297e-10, 1.94165677993092e-9, -3.30056802855781e-9, 2.5754550317415987e-9, -7.069862662139731e-10]), layer_4 = (weight = [-0.0007500012003607034 -0.0006868103270920246 -0.0007778724109058173 -0.0006215880192638222 -0.0007244502913846811 -0.0006437118613089832 -0.0007130812750148795 -0.0005082446206375883 -0.0004683209651099115 -0.0006013905695234959 -0.0007535509953015446 -0.0007356870271018841 -0.0005552057710251482 -0.0007046472804759755 -0.0007961658466528801 -0.0006842688077713253 -0.0008164024847272758 -0.0007264936994726606 -0.0007253339798820379 -0.0008106880165880577 -0.0006020250852822088 -0.0009944743099207572 -0.0007747542857018684 -0.000532701261077426 -0.0008825574832786258 -0.0007606873755916946 -0.0005291255423385058 -0.0007508253850462414 -0.0005613420034245129 -0.0007059044576768909 -0.0006295881478662563 
-0.0006891405982665653; 0.00024855678041969223 0.0002629594690392304 0.0004005727939517843 0.0002959015221626048 0.0001898003466673765 0.0001740604385965318 0.00016292172810596366 0.0001890594320668616 0.0004055056875506919 0.0001792514382015605 0.00031314894764429915 0.0002815223387107052 4.899063518907844e-5 0.00027149211795012435 0.00017032871851707982 0.00027628839553981566 0.0002942529289685582 0.0002185407561273581 0.00017808428520112372 9.538289872653837e-5 2.366360522858449e-5 0.0002997399549609901 0.0002865407561065287 0.000303854808217112 0.00023105833285462613 0.0002420118073089214 0.0003564649560959204 0.00011489272321534026 0.00017846410803689522 0.0003398897899729668 0.00013719773747466733 0.000298823777795369], bias = [-0.0006741924216124024, 0.00022612388329802]))Visualizing the Results
Let us now plot the loss over time
begin
    # Loss history collected by `callback`, one point per optimizer iteration.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, eachindex(losses), losses; marker=:circle, markersize=12, strokewidth=2)
    fig
end
# Finally let us visualize the results
# Re-integrate the neural ODE with the optimized parameters `res.u` and convert the
# trajectory into the trained model's waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    # Overlay the data, the untrained prediction, and the trained prediction.
    line_style = (; linewidth=2, alpha=0.75)
    marker_style = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; line_style...)
    s1 = scatter!(ax, tsteps, waveform; marker_style...)
    l2 = lines!(ax, tsteps, waveform_nn; line_style...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_style...)
    l3 = lines!(ax, tsteps, waveform_nn_trained; line_style...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_style...)

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
# Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
# GPU backend details are printed only when MLDataDevices and the backend package are
# loaded and the corresponding device is functional; `&&` short-circuits so no backend
# name is touched unless it is defined.
if @isdefined(MLDataDevices) &&
   @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) &&
   @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
# Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 9V74 80-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver4)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.