Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
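We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$: with $M = m_1 + m_2$, the positions relative to the centre of mass are $r_1 = \frac{m_2}{M}\, r$ and $r_2 = -\frac{m_1}{M}\, r$.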
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of the two bodies
about their centre of mass. With `M = m₁ + m₂`, return `(m₂ / M) .* path` for body 1 and
`(-m₁ / M) .* path` for body 2, so that their difference is `path` again.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    return (m₂ / total_mass) .* path, (-m₁ / total_mass) .* path
end
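Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, with $r = \frac{p}{1 + e\cos\chi}$, $x = r\cos\phi$ and $y = r\sin\phi$. The constants $p$ and $e$ are either extra rows of the solution or are taken from `model_params`: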
# soln2orbit(soln, model_params=nothing)
#
# Convert a solution in orbital variables into Cartesian positions.
# Each column of `soln` is one time sample. The rows are either
#   * `χ, ϕ` (2 rows) — then `p` and `e` come from `model_params = (p, M, e)`, or
#   * `χ, ϕ, p, e` (4 rows).
# With `r = p / (1 + e cos χ)`, the result is the 2×N matrix with rows
# `x = r cos ϕ` and `y = r sin ϕ`.
# Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if it has 2 rows and
# `model_params` is missing or does not have exactly 3 entries.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` first: `length(nothing)` would raise a MethodError rather
        # than the intended, descriptive assertion.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params  # the mass M is not needed for the geometry
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
The first time derivative is computed with centred second-order differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v, dt)

First time derivative of the uniformly sampled series `v` with spacing `dt`.
Interior points use centred differences and both ends use second-order one-sided
stencils, so the result has the same length as `v`. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(left, interior, right) / dt
end
The second time derivative is computed with centred second-order differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v, dt)

Second time derivative of the uniformly sampled series `v` with spacing `dt`.
Interior points use the centred three-point stencil and both ends use second-order
one-sided four-point stencils, so the result has the same length as `v`.
Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    body = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    return vcat(head, body, tail) / (dt^2)
end
Now we define a function to compute the trace-free moment tensor from the orbit, $I_{ij} = m\left(x_i x_j - \tfrac{1}{3}\,\delta_{ij}\,(x^2 + y^2)\right)$:
"""
    orbit2tensor(orbit, component, mass=1)

Return the time series of one entry of the trace-free moment tensor of `orbit`, whose
first and second rows are `x` and `y`. The entry is scaled by `mass`.

`component == (1, 1)` gives `x² - (x² + y²)/3` and `component == (2, 2)` gives
`y² - (x² + y²)/3`. Any other component gives the off-diagonal entry `x*y`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    row, col = component[1], component[2]
    if row == 1 && col == 1
        xx = x .^ 2
        entry = xx .- (xx .+ y .^ 2) ./ 3
    elseif row == 2 && col == 2
        yy = y .^ 2
        entry = yy .- (x .^ 2 .+ yy) ./ 3
    else
        entry = x .* y
    end
    return mass .* entry
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (spacing `dt`) of the `component` entry of the
trace-free moment tensor of `orbit`, as produced by `orbit2tensor`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(11, 12, 22)` entries of the quadrupole term for a single body of the given
`mass` moving on `orbit`, each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    entry(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return entry((1, 1)), entry((1, 2)), entry((2, 2))
end
"""
    h_22_strain_one_body(dt, orbit)

Plus and cross polarisations of the (2,2)-mode strain for a single unit-mass body on
`orbit`. Both are scaled by `√(π/5)` in the precision of `dt`; the cross polarisation
carries a minus sign.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    amplitude = sqrt(T(π) / 5)
    plus = amplitude * (h11 - h22)
    cross = -amplitude * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum, entry by entry, the `(11, 12, 22)` quadrupole terms of two bodies with masses
`mass1` and `mass2` on `orbit1` and `orbit2`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body_one = h_22_quadrupole(dt, orbit1, mass1)
    body_two = h_22_quadrupole(dt, orbit2, mass2)
    return map(+, body_one, body_two)
end
"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Plus and cross polarisations of the (2,2)-mode strain for two bodies (BH1 of `mass1` on
`orbit1`, BH2 of `mass2` on `orbit2`). The masses must sum to one (within `1e-12`),
otherwise an `AssertionError` is raised.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    amplitude = sqrt(T(π) / 5)
    return amplitude * (h11 - h22), -amplitude * (T(2) * h12)
end
"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Return the `(h₊, hₓ)` strain of the (2,2) mode for the orbit encoded in `soln` (converted
with `soln2orbit`), sampled with spacing `dt`.

`mass_ratio` must lie in `[0, 1]`; otherwise an `AssertionError` is raised.
- A positive ratio `q` splits the relative orbit between two bodies with masses
  `m₁ = q / (1 + q)` and `m₂ = 1 / (1 + q)`.
- A zero ratio uses the one-body (test-particle) formula.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
Simulating the True Model
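`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.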
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side `[χ̇, ϕ̇]` of the orbit ODEs for a point-like particle in a Schwarzschild
background. The state is `u = [χ, ϕ]` and `p, M, e` are constants; the rates depend only
on `χ` (`ϕ` and `t` are unused).
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)
    numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end
# Tutorial-wide settings and reference data. `loss` reads several of these globals on
# every optimizer iteration, so they are declared `const` to keep it type-stable.
const mass_ratio = 0.0 # test particle
const u0 = Float64[π, 0.0] # initial conditions (χ₀, ϕ₀)
const datasize = 250
const tspan = (0.0f0, 6.0f4) # timespace for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
const dt_data = tsteps[2] - tsteps[1]
const dt = 100.0 # fixed RK4 step size (adaptive stepping is disabled below)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Only the plus polarisation h₊ (first element of the returned tuple) is used as data.
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the reference plus-polarisation waveform as a line with markers.
# `begin` does not open a new scope, so `fig`, `ax`, `l` and `s` are bound as globals.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    # grouping `[l, s]` shows the line and its markers as a single legend entry
    axislegend(ax, [[l, s]], ["Waveform Data"])
    # `fig` is the value of the block, so it is what gets displayed
    fig
end
Defining a Neural Network Model
Next, we define the neural network model that takes one input (the orbital angle $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the state, the neural network parameters and a time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
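We will deviate from the standard neural network initialization and use WeightInitializers.jl: every `Dense` layer starts from truncated-normal weights with standard deviation $10^{-4}$ and zero biases.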
# Network: an elementwise `cos` feature map on the scalar input, two hidden `cos` layers
# of width 32, and two linear outputs. All `Dense` layers share one initializer object
# (truncated-normal weights with std 1e-4) and start with zero biases.
const nn = let init_weight = truncated_normal(; std=1.0e-4)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight, init_bias=zeros32),
        Dense(32 => 2; init_weight, init_bias=zeros32),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-8.9388064f-5; 4.5351503f-6; -0.00015214137; 6.114776f-5; -1.1209134f-5; 8.799054f-6; 0.00011623872; 3.05429f-5; -3.83482f-6; 0.00015138701; 2.1730368f-6; -5.228177f-5; 9.921482f-5; 4.3510307f-5; -0.00012580592; 3.8496064f-6; -0.000108541; 7.900059f-5; 0.00011266743; 6.339205f-5; -0.00012326815; -6.280487f-5; 0.00015057165; -0.00013831796; 4.2652668f-5; 7.461641f-5; -4.27946f-5; 1.3773994f-5; -5.564968f-5; 9.758318f-6; 7.780476f-5; -3.1738808f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[0.00011451244 -9.797832f-6 -1.8448734f-5 9.298047f-5 -0.000101742924 3.2806664f-5 -0.00020098341 -6.699487f-5 0.00011111761 0.00015806322 -0.0002444207 9.812995f-5 3.0465548f-5 0.00017377714 7.61732f-5 -0.00011140031 -4.0784817f-5 0.00014837775 0.00012752574 -0.00010261932 7.1626f-5 -6.206286f-5 0.00013178514 3.759823f-6 -0.00023320709 6.1027804f-5 -1.6187832f-5 3.891618f-5 -5.8271897f-5 -0.00019433236 7.4522395f-5 3.6655307f-5; -5.309447f-5 2.5810414f-5 -1.1221937f-7 0.00010407169 -0.00014479396 -4.9464652f-5 7.447467f-5 -5.1916973f-5 0.00018844358 0.000108912835 0.00010271982 -8.095074f-5 0.00012696919 -6.253552f-5 5.766943f-5 -4.49821f-6 1.4974889f-5 7.519112f-5 -0.00016329327 1.2466798f-5 0.0001649452 0.00011111219 1.1500285f-5 -0.00020547253 0.00018217484 -6.843421f-5 0.00018746815 0.00019399593 -4.504041f-5 2.4008135f-5 0.00018855676 -0.00016264507; -9.604787f-5 -7.176783f-5 -0.0001075844 3.934159f-5 0.00011421174 9.24785f-5 -1.4486345f-5 6.4913875f-5 8.990758f-5 0.00026085938 0.00021532897 3.5675006f-5 3.4372567f-7 5.959175f-5 -8.501883f-5 -0.00010264326 -0.000113325535 0.0002818491 0.00017741721 4.2463205f-5 1.3803634f-5 -5.6066303f-5 0.00012256048 -6.016216f-5 -0.00017384888 -9.481337f-5 1.5512056f-5 -8.695179f-5 4.9197457f-5 -1.2734439f-5 4.5890883f-5 
4.61415f-5; 0.00019276525 -8.068402f-5 4.2769203f-5 -0.00011042768 -6.734364f-6 8.257294f-5 -6.0721606f-5 0.00012129415 -4.9118546f-5 -5.8978567f-5 -6.6479f-5 6.3245756f-7 -5.2841966f-5 8.4902706f-5 7.28772f-5 3.4881132f-6 -9.222479f-5 -0.00012041015 2.0756464f-5 -3.0037312f-5 -6.7468085f-5 6.50638f-5 -0.00010362366 -0.0001370196 0.00016208265 0.00013595226 -0.000113768045 9.020066f-5 0.00012746826 7.1696406f-5 -6.199798f-5 1.8654147f-5; 3.0187088f-5 -8.2569204f-5 1.630463f-5 0.00018114847 -1.3182372f-5 -9.1029535f-5 -2.034304f-5 -0.00023838333 2.2424629f-5 0.0001249759 0.00013387649 4.847546f-5 5.7377398f-5 -5.2304043f-5 6.600685f-5 -0.00018184219 -5.8643553f-5 -7.334203f-5 9.371442f-6 -0.0001508627 -9.431321f-5 -3.5020803f-5 1.6114805f-6 0.00014357969 4.5270684f-5 -6.522017f-5 -8.7969616f-5 -0.00014924032 -0.00013493194 0.00020919085 0.00015619917 -2.533285f-5; 1.245501f-6 -5.8053487f-5 1.0311617f-5 -8.571106f-5 -9.890795f-5 -7.806836f-5 5.116477f-5 0.00013933415 8.684726f-5 -0.00014101288 3.598061f-5 -1.5622922f-5 -0.00027564284 -3.5021807f-5 0.00017156362 5.614879f-5 -0.00012805013 -0.00013819509 3.4784272f-5 -5.5007466f-5 -2.1965214f-5 1.5279136f-5 0.00015468636 -1.4610767f-5 -6.494188f-5 -3.432421f-5 5.8900852f-5 2.4512945f-5 4.187701f-5 -6.745216f-5 8.099909f-5 -0.00026460743; -0.00012759445 -6.175785f-5 -0.00011273406 -0.00029633185 -3.0289075f-5 3.1951677f-5 -2.083552f-6 -0.0001666729 -4.2433956f-5 0.00014505614 -0.00014206332 4.715079f-5 -0.000119468285 -3.5891364f-5 -8.672946f-5 0.00016388687 6.79814f-5 0.00014028588 -7.6346205f-6 -5.176354f-6 -0.00021809309 -2.837664f-5 -4.0232815f-5 -8.091506f-5 7.7111036f-7 -0.00015540505 -4.97628f-6 2.5112678f-5 -3.501507f-5 -0.000113821225 0.00013189092 0.00016229371; -8.755021f-5 5.5652818f-6 -6.364144f-5 -0.00010722025 -0.000109981054 7.96973f-5 4.4435932f-5 -0.00020811321 -9.5522366f-5 -0.00012342629 -1.6395996f-5 0.00012153134 2.4608906f-5 4.8350546f-5 0.00018568466 2.5120926f-5 4.3319567f-5 -5.599086f-5 
3.2481069f-7 7.138322f-5 -4.358374f-5 -9.787589f-5 -0.0001268871 -8.171554f-5 3.5883957f-5 9.467212f-5 0.00014990337 9.167681f-5 -6.360248f-5 0.0001005458 7.2274556f-6 1.3446176f-5; -5.3957283f-5 8.095025f-5 0.00011621817 6.970343f-5 -0.0001352713 -4.4810704f-6 -4.6316334f-5 0.00013344588 -0.00011283851 9.212718f-5 5.3059415f-5 9.319101f-6 -0.00012393328 4.784013f-5 -1.268433f-5 0.00020009356 -0.00015104808 2.7466453f-5 4.139534f-5 7.654623f-5 -8.91136f-5 -9.813408f-5 3.3218712f-5 0.00015736371 0.00018051024 -2.775125f-5 7.835374f-5 -4.958598f-5 -0.00011651152 -4.7175286f-5 -0.00011173756 3.8222697f-5; -1.5023104f-5 -3.5324534f-5 -0.00010319622 0.00018093767 8.164724f-5 -0.00016378713 -5.9146838f-5 0.00016479811 4.923082f-5 -4.7988113f-5 1.1801825f-5 -8.611829f-5 -0.000120305245 -0.0001300813 0.00018071475 -1.7266313f-5 -2.0582222f-6 -8.639165f-6 5.4479955f-5 -5.760427f-5 4.577293f-5 -0.000105143976 -2.2331993f-5 6.609562f-5 6.3382344f-5 6.4138076f-5 -9.7996f-5 0.00026974245 -8.8599845f-5 0.0001309458 0.00016414006 -1.8274482f-5; -0.0001109574 -7.3109586f-5 -7.6040415f-5 7.8192454f-5 0.00025282896 9.94007f-6 -4.4644994f-6 -0.00012892902 0.00018111315 4.129721f-5 -5.3253207f-5 8.521478f-6 2.2510762f-6 -8.0197016f-5 0.00014440097 -0.000110914894 -2.9317272f-5 7.741688f-6 -2.9623903f-5 2.0568286f-5 8.564952f-5 -0.00014406512 2.7458882f-5 0.00014535517 -0.00035004693 4.9594066f-5 6.684204f-5 -0.00018068583 0.00023588333 0.00013359997 0.00010278468 1.6597414f-5; 4.894526f-5 -0.00010558271 -4.3493612f-5 -6.1398205f-5 4.505317f-5 -1.09062785f-5 -0.000111177986 6.3364525f-5 0.0001266731 3.7373218f-6 -1.3166695f-6 5.469897f-5 0.00019612254 6.496494f-5 -8.577877f-5 2.0738207f-5 -3.004849f-5 4.297248f-7 -3.3935034f-5 -0.00013560959 4.701431f-5 2.0143174f-5 -0.00019741335 -3.9282797f-5 -0.00011196898 -0.00011016113 -0.000113413866 -3.8850147f-5 -0.00019757394 5.570136f-5 2.7646922f-5 -5.3435935f-5; -0.0001756195 -0.0001348344 9.610759f-5 2.1921871f-5 1.2461769f-5 
-0.00015479933 9.201945f-5 6.240842f-5 0.000110804474 0.00014819467 -4.990117f-6 -1.2823267f-5 7.1083065f-5 7.439646f-5 0.00016061463 -7.1762566f-5 1.6463688f-5 -6.255691f-5 -0.0001165426 -3.1576994f-5 -0.00013487077 -5.190689f-5 -8.762451f-5 -7.8138455f-6 -2.1736048f-6 -1.5824846f-5 5.0756098f-5 1.5318135f-5 -2.3828161f-5 -2.3514496f-5 7.5795164f-5 -0.000120392986; 8.9044726f-5 -5.5627046f-5 -9.627247f-5 0.00021072343 1.635785f-5 -9.354788f-5 -8.9777786f-5 -5.7232348f-5 -7.0324177f-6 0.00011867894 -7.987537f-5 7.37332f-5 -0.00012483081 -4.0684237f-5 0.0002216954 7.509968f-5 -0.00013336226 0.00012464907 9.7256016f-5 0.00011228887 -0.00019075554 8.628432f-5 7.759798f-5 0.00020596318 2.9669382f-5 -0.00021442384 -3.9654038f-5 -0.00015567531 0.00018913783 -8.2195955f-5 9.485527f-5 -0.00012301945; -4.500545f-5 3.5956637f-6 -5.819044f-5 0.00024066369 7.057828f-5 -2.668541f-6 0.00018198336 -7.401057f-5 6.8512745f-6 0.00020965474 0.00013593683 0.00027906138 -0.00010785144 0.00019427043 9.057679f-5 9.320194f-5 -1.8336015f-5 -1.9989006f-5 0.00010548821 -0.00013083236 -2.8838147f-5 -3.670619f-5 -9.656982f-5 -0.00014254377 -2.6001335f-5 -0.00010295841 -0.00015845815 0.00010200021 1.9629311f-5 0.00011851859 3.8390364f-5 0.00014586515; 0.00016751661 -6.405659f-5 7.769193f-5 8.975841f-6 0.00014901489 -0.00025390115 6.730218f-5 -5.495427f-5 0.00026489573 -1.2885228f-5 7.056727f-5 -0.00010104014 -0.00013085247 0.0001416214 -0.00013230045 8.866751f-5 0.00011130718 7.364876f-5 3.8413422f-5 -2.7115018f-5 9.9668396f-5 8.313853f-5 -2.1967127f-5 6.571331f-5 -7.981163f-6 9.778293f-5 -0.0001303662 -8.971247f-5 6.0051385f-5 -5.979386f-5 3.230473f-5 2.3944089f-5; 0.00019150904 -0.00016667384 0.00020109073 -3.7002286f-5 0.00020823033 -2.9182218f-5 8.028828f-5 2.5291569f-5 0.00035813876 3.2219632f-5 -2.8556156f-5 -6.2879617f-6 -0.00012385703 -4.7432222f-5 -6.448928f-5 5.7140922f-5 -5.552637f-6 3.104605f-5 -4.8373928f-5 -0.00019075842 1.0957801f-5 4.5672838f-5 -1.9162837f-5 -4.8022634f-5 
7.439682f-5 -0.00013567439 -1.2075661f-5 0.00010116041 8.84392f-5 0.00014342566 -6.684118f-5 -0.00017913808; 2.5303612f-5 -9.042177f-6 -2.0057405f-5 -5.8437025f-5 -0.00024210024 8.380773f-5 -9.142809f-5 0.00010927875 -0.00016955315 -9.0014575f-5 -9.68634f-5 -0.00012674881 -0.00011367435 0.0001316939 -5.8192072f-5 -0.0001392263 0.0001433972 5.7864396f-5 0.00015347387 -5.7640744f-5 4.6358353f-5 -4.1196254f-5 1.9570263f-5 -4.9212973f-5 7.629284f-5 -7.736663f-5 3.2848908f-5 -0.00024378933 3.307202f-5 -6.9625545f-5 -5.335751f-6 4.7842354f-5; 7.429204f-5 -1.6023338f-5 -3.933759f-5 8.593395f-6 0.0001347484 -4.943457f-5 -0.00015114376 0.00010197669 -1.0414226f-5 -4.232435f-5 -0.00013986156 1.1609421f-5 -8.537325f-5 -6.0289996f-5 -6.191453f-5 -0.00019725274 5.6371705f-6 0.0001454031 -0.00014759955 -2.9437999f-5 -0.00010341728 0.00015686535 -7.399175f-5 3.354107f-5 3.1237946f-6 -3.532328f-5 7.4126074f-5 0.00014947413 4.2570813f-5 -7.920741f-5 -3.188879f-5 5.8826616f-5; 4.5898356f-5 -0.00021049994 -0.00013618497 -5.4310323f-5 -4.590309f-5 -1.532831f-5 4.1352694f-5 -9.043282f-5 8.907289f-6 -8.762086f-5 -3.4954983f-6 0.00013628697 6.372965f-5 -2.160817f-5 -7.5085976f-5 -0.00018109397 5.1842744f-5 5.9965827f-5 0.00011077812 1.5870823f-5 -0.00014709691 0.00010015112 -0.00010941574 -6.7342975f-5 0.00011384917 -1.805893f-5 0.000144156 -2.929432f-6 -1.8218592f-5 -3.4312798f-5 -1.5657042f-5 -7.305215f-5; 0.00019923133 0.0001609701 4.970292f-5 0.00015767611 3.3199194f-5 5.295157f-5 4.189899f-5 0.000119795564 -2.7155882f-5 -1.5478669f-5 -1.9040122f-5 7.071769f-5 -4.9245133f-5 1.662453f-5 -3.2830976f-5 -5.4755856f-5 -3.4263187f-5 -5.7640114f-5 5.5642962f-5 6.909018f-5 -0.00013154361 -6.3458974f-6 1.7948707f-5 -9.331984f-5 6.3150364f-5 -0.00012460348 -0.00020954428 -0.00014148465 6.229265f-6 -0.000117293326 5.833916f-5 3.1278f-5; -0.0001067559 4.580052f-5 5.0685263f-5 3.1908297f-5 -9.131171f-5 7.592512f-5 5.670526f-6 -1.629011f-5 9.918572f-5 -1.330123f-5 -8.495466f-5 -0.00024187574 
0.00011190954 -9.600457f-5 6.0324637f-5 -0.00016555433 1.8871959f-5 8.129774f-5 -6.660211f-6 2.1201551f-5 0.00018786445 7.3364936f-5 4.7511083f-5 -7.891183f-5 -9.588667f-5 -0.00010285343 -3.197771f-5 -6.46538f-5 -0.00017598418 -0.0001121735 4.1562496f-5 -2.6662103f-5; 3.562712f-5 -0.00016206814 -0.00014809163 0.0002016972 0.00014650064 -4.8938306f-5 7.822209f-5 0.00010327005 5.4330514f-5 -0.00014183664 -5.528146f-5 0.00012893822 6.3572304f-5 8.5229396f-5 -8.652162f-5 5.246516f-5 0.00015851161 2.6026577f-5 -3.2066648f-5 0.00015432402 9.894136f-5 -1.0106837f-5 0.000119336524 0.00011091653 2.8125694f-5 0.0001510297 -0.00023216972 -0.00013089937 6.542886f-5 8.6616266f-5 0.00018007317 -2.559896f-5; -5.0164606f-5 -3.7097398f-5 0.000113823655 -0.00017891802 -9.902323f-5 4.632575f-5 1.7866767f-5 -0.00010904528 5.5917666f-5 -0.00011505141 -0.000118628464 -7.7514276f-5 5.0476036f-5 -0.00014608978 3.546836f-6 0.00010498919 -7.223308f-5 0.00017503745 7.1214265f-5 -7.0753376f-5 -1.9573323f-5 -1.2963226f-5 -7.476161f-5 0.00016279709 -8.4671745f-5 0.00012089024 1.0373263f-5 0.00013308413 0.0001960544 -9.005978f-6 -6.315092f-5 -1.5508653f-5; 8.501915f-5 0.00011784553 1.086995f-5 -0.00028398313 -9.614382f-5 -2.4861516f-5 0.00013437707 -4.855439f-5 3.4479155f-5 -2.2151764f-5 -1.9358784f-5 7.7908844f-5 4.6163455f-5 -7.051595f-5 -0.00018663336 9.081326f-5 -8.394663f-5 -4.452978f-5 4.2994998f-5 -8.378879f-5 -5.361333f-5 -2.8442537f-5 5.7980484f-5 -8.051073f-5 -2.8665732f-5 6.66307f-5 0.00021859977 -9.809439f-5 0.00019976702 -8.579535f-5 2.1915139f-5 -7.8818244f-5; -0.00014983681 0.00011836864 9.792958f-5 0.00013217321 8.850639f-5 -0.00010567408 -0.000120884535 -2.1363201f-5 4.8748814f-5 -5.9290753f-5 -0.00011701956 -0.00015289572 -2.9265435f-5 -6.6320295f-5 -0.0001346363 0.00011162893 -2.2486865f-5 6.540289f-5 1.0023361f-5 0.00010989781 5.398345f-5 5.2293286f-5 4.0248768f-5 -0.00028875034 -7.2902854f-5 0.000100961704 -1.2767729f-5 -0.00012009277 -2.6314741f-5 0.0001080049 0.00010207643 
-8.753533f-5; -7.705559f-6 8.3027844f-5 -2.1274296f-5 -7.327287f-5 0.0002607331 0.00011859058 -0.00018668911 7.511506f-5 -0.0002633629 0.00016362451 0.00016061071 -0.00021387398 4.11373f-5 -0.000109145454 -5.092219f-5 -0.00012310872 3.2011852f-5 -8.308074f-5 -2.189896f-5 -1.8832889f-5 0.00012918616 6.35812f-5 0.0001742774 -0.00018741156 -0.00018348593 -8.082411f-5 2.1807138f-5 -0.000109552646 5.954515f-5 -0.00010478444 -3.1968513f-5 1.33226495f-5; -3.8143164f-5 5.1447467f-5 3.918536f-5 -8.866672f-5 -5.595521f-5 3.958306f-5 0.00010256385 0.000100220605 0.000106843494 3.893567f-5 0.00017383945 -0.00012402602 -9.169229f-5 -0.00011046578 -5.7445253f-5 -0.00015208196 0.00015202144 -0.00031237077 5.7432593f-5 -5.334002f-5 8.0130696f-5 0.0002549844 -4.426911f-5 0.00013838173 4.8268543f-5 -5.0967188f-5 -6.099455f-5 -0.00020200713 0.00023476113 6.279807f-5 -1.0891022f-5 0.00011915383; 9.511108f-5 1.590476f-5 0.00021669657 3.7027487f-5 7.70771f-5 -1.100163f-5 6.331661f-6 6.736898f-5 -2.7317414f-5 5.54119f-6 -3.389802f-5 -7.239503f-5 -9.2884155f-5 3.7709317f-5 -0.00015635855 1.9064664f-5 -1.6506381f-5 3.4989476f-5 -0.00020051587 -3.531614f-5 -7.096179f-5 -0.00013401099 -5.070208f-5 0.00017651837 -3.8281494f-5 5.944739f-5 5.976133f-5 -9.050254f-5 5.571898f-5 1.40531f-5 -4.782073f-6 8.668025f-5; 9.809527f-6 -8.389827f-5 8.335371f-5 -0.00015235825 1.6846783f-5 -4.8843125f-5 -3.90877f-5 4.424726f-5 -4.423497f-5 2.7725475f-5 6.831684f-5 -2.5708017f-5 -0.000112615926 -4.6403773f-5 -4.4880413f-5 0.00013654277 5.7352383f-5 -0.00012454382 9.671555f-5 -9.47112f-5 -0.00012977555 9.2543836f-5 5.4069533f-5 3.258067f-5 -6.58506f-5 -5.3593118f-5 1.7590854f-6 6.32998f-5 3.395042f-5 -7.4046075f-5 0.0002597462 0.00010232807; -0.0001045241 -0.00013657067 -7.847691f-5 -0.00012554336 -2.913408f-5 4.4106655f-5 -2.8391576f-5 -9.0365415f-5 -3.0923373f-5 2.6962211f-6 0.00028797015 0.00024057776 9.248801f-5 -0.000110375215 0.00013936564 -0.00013945096 3.2219552f-5 7.868814f-5 1.9056066f-5 
0.00018771789 7.608242f-5 -0.00016925139 -4.8017617f-5 0.0002476875 -0.00023271536 9.024287f-5 -0.00013282671 6.9768874f-5 -1.321868f-5 3.125106f-5 8.8771034f-5 -1.9706953f-5; -0.00012419437 2.529067f-6 -0.00010135445 3.343802f-5 0.00011334412 -1.6700486f-5 -4.0614683f-5 9.448913f-5 -1.4398691f-5 -7.597206f-6 -0.00019119048 1.7266342f-5 0.00010161194 5.122987f-5 0.00017874631 -4.8870566f-5 -0.00010169631 5.687608f-5 -6.132848f-6 2.2245298f-5 4.881135f-7 4.2870306f-5 -8.692944f-5 -0.00013529243 4.3506356f-5 -0.00013215942 0.0001309861 5.2557658f-5 -5.7961453f-5 5.72481f-5 -9.27197f-5 2.8917406f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-4.392225f-5 8.568757f-5 -3.834954f-5 -0.00014281877 -7.720324f-5 3.076116f-5 0.0001606887 -0.00010985947 -2.821813f-5 0.00018598614 -0.000102904254 -2.4525569f-5 -0.00010475418 -1.8755667f-5 -5.5632732f-5 -7.961128f-5 6.6682944f-5 -0.00031220037 1.985713f-5 -6.2012194f-5 -5.382837f-5 2.989172f-5 0.00016529008 -0.00012945072 0.00015198824 5.6323377f-5 -3.47265f-5 -5.4345357f-5 -5.399769f-5 -9.22417f-5 0.0001359759 9.7391305f-5; 1.9812396f-5 -4.0389146f-5 -2.9355737f-5 -1.4708296f-5 -0.00013533313 6.435192f-5 -0.0001034136 3.9406787f-5 -3.9279825f-5 -0.00013719458 -7.581318f-5 2.544754f-5 0.00010459265 -4.6503163f-5 -9.710853f-6 -8.7734705f-5 5.8577214f-5 -4.9642535f-5 8.980928f-5 -8.717008f-5 -3.111253f-5 -0.00010329253 -9.8770826f-5 -7.214658f-5 -2.8971064f-5 2.3889945f-5 -4.8450707f-5 -6.300908f-5 -0.00017256681 -0.00026989216 0.00013713035 -7.864003f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most deep learning frameworks, Lux defaults to using `Float32`. However, in this case we need `Float64`, so we convert the parameters with `f64`.
# Promote the Float32 parameters to Float64 and wrap them in a ComponentArray, so the
# optimizer can treat them as one flat vector.
const params = ps |> f64 |> ComponentArray
# Bind the layer states to the model; parameters are still passed explicitly at call time.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
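Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected by the neural network $\mathcal{N}$, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $p$, $M$, and $e$ are constants.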
"""
    ODE_model(u, nn_params, t)

Right-hand side `[χ̇, ϕ̇]` of the Newtonian orbit ODEs for the state `u = [χ, ϕ]`, with
each rate multiplied by `1 + N(χ)`. Here `N` is the neural network `nn_model` evaluated
with `nn_params`. The constants `p, M, e` are read from the global `ode_model_params`;
`t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # This network's `st` is an empty NamedTuple, so the stateful wrapper can be called
    # with the parameters alone; in general `st` carries the network's state.
    correction = 1 .+ nn_model([χ], nn_params)
    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [rate * correction[1], rate * correction[2]]
end
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Simulate the Newtonian + (untrained) neural-network model on the data time grid.
# `u0` and `p` are also passed to `solve`, with the same values stored in `prob_nn`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
# Overlay the reference waveform and the waveform predicted by the untrained neural ODE.
# `begin` does not open a new scope, so `fig`, `ax`, `l1`, `s1`, `l2` and `s2` are
# (re)bound as globals.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # reference data: line + markers
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    # untrained neural-ODE prediction: line + markers
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    # each `[line, scatter]` pair is shown as a single legend entry
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    # `fig` is the value of the block, so it is what gets displayed
    fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the reference `waveform` and the plus-polarisation waveform
predicted by the neural ODE `prob_nn` with parameters `θ`. The prediction uses a
fixed-step RK4 solve on the `tsteps` grid.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted = first(compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params))
    return mseloss(predicted, waveform)
end
Warm up the loss function by evaluating it once at the initial parameters.
# Evaluate the objective once at the initial (untrained) parameters, so the first call's
# compilation happens before training starts.
loss(params)
0.0006763513696001632
Now let us define a callback function that records (and prints) the loss at every iteration.
const losses = Float64[]

"""
    callback(state, l)

Append the current objective value `l` to `losses` and print a progress line using
`state.iter`. Always returns `false`, so the optimizer never stops early.
"""
function callback(state, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, l)
    return false
end
Training the Neural Network
Training uses the BFGS optimizer. This gives good results here because the Newtonian model already provides a very good initial guess.
# Minimise `loss` with BFGS (backtracking line search, small initial step), computing
# gradients with Zygote. `callback` logs every iteration; stop after 1000 iterations.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-8.938806422514726e-5; 4.5351503104044775e-6; -0.00015214136510622672; 6.114775896985776e-5; -1.1209133845097925e-5; 8.799053830450136e-6; 0.00011623871978354958; 3.05428984574866e-5; -3.83481983589441e-6; 0.0001513870083725434; 2.1730368189271052e-6; -5.228177178644056e-5; 9.921481978383298e-5; 4.351030656831392e-5; -0.00012580592010630502; 3.849606400760113e-6; -0.00010854099673450575; 7.900058699304247e-5; 0.00011266743240410912; 6.33920499239676e-5; -0.00012326815340187242; -6.280487286859402e-5; 0.00015057165001046115; -0.00013831796240977562; 4.2652667616475365e-5; 7.461640780086796e-5; -4.279459972173437e-5; 1.377399439660487e-5; -5.564968159880926e-5; 9.758317901273924e-6; 7.780476153124514e-5; -3.173880759276792e-5;;], bias = [-6.645422995410271e-17, -1.7148753188157285e-18, 9.995005268341798e-18, 2.095635098809355e-16, -1.8702122895345057e-17, -7.454012346947663e-18, 2.0413931215277323e-16, 1.254471120737414e-17, -8.625199888480758e-18, 1.7817547953730006e-16, 4.552359521262397e-18, -8.257994645939841e-17, -2.769947097597213e-17, 4.2917175044768115e-17, -6.283253036393875e-17, 3.6540933651398515e-18, -1.3120278061962626e-17, 2.1130124738711066e-17, 1.5574828705355282e-17, 4.3499697976688064e-17, -1.7462245528652635e-16, 1.2863340124472934e-17, 8.458722649158442e-17, -1.786156618515693e-16, 2.1174986221436126e-17, 7.460398350081226e-17, 7.077152589877468e-17, 1.2383244332910897e-17, -8.254649361598538e-17, 1.5514385475273563e-17, 6.071339246078305e-17, -7.068233324686493e-18]), layer_3 = (weight = [0.00011451291297242465 -9.79735946777199e-6 -1.8448262068268887e-5 9.298093877412865e-5 -0.00010174245167610627 3.2807135852347905e-5 -0.00020098294103838165 -6.699439896357648e-5 0.00011111808211688494 0.0001580636956929807 -0.0002444202188201566 9.813042501599097e-5 3.0466020092427907e-5 0.00017377761045616634 7.617367313621863e-5 -0.00011139983749336953 -4.0784344179263473e-5 
0.00014837821740337165 0.00012752621054794988 -0.00010261884804669287 7.162646883778581e-5 -6.206239109480272e-5 0.00013178561434295223 3.7602952431190124e-6 -0.00023320661480858405 6.102827623515848e-5 -1.6187360103064656e-5 3.89166519595923e-5 -5.8271424364936236e-5 -0.0001943318863198479 7.452286749292614e-5 3.6655779098221735e-5; -5.309100993377336e-5 2.5813874569807936e-5 -1.0875851564727414e-7 0.00010407514969069815 -0.00014479049673586373 -4.946119111313698e-5 7.447813299899116e-5 -5.19135126092316e-5 0.0001884470377694515 0.00010891629540076471 0.00010272328404194472 -8.094727923212157e-5 0.0001269726465101874 -6.253205606530839e-5 5.767289007729713e-5 -4.494748966449133e-6 1.4978350234232803e-5 7.519458199335281e-5 -0.0001632898081448373 1.2470258523035252e-5 0.00016494866326632846 0.00011111565004052872 1.150374577939761e-5 -0.00020546907285512938 0.0001821783055197127 -6.843074766270721e-5 0.00018747160833977684 0.00019399939378441082 -4.503695049838718e-5 2.401159622191237e-5 0.00018856022256609096 -0.00016264160763289164; -9.604515788426788e-5 -7.176511270274789e-5 -0.00010758168344668703 3.9344304345531695e-5 0.00011421445389230345 9.24812121544833e-5 -1.448363117136952e-5 6.491658870315645e-5 8.991029618725731e-5 0.0002608620926499962 0.0002153316819337065 3.567772011108072e-5 3.464397101778511e-7 5.9594465814981706e-5 -8.501611581354334e-5 -0.00010264054425108628 -0.00011332282145789396 0.0002818518042792379 0.00017741992643519652 4.246591935468099e-5 1.3806347955922894e-5 -5.606358889538596e-5 0.00012256319324901714 -6.015944735313063e-5 -0.0001738461650214306 -9.481065343182481e-5 1.5514770088586176e-5 -8.694907581714379e-5 4.920017101055285e-5 -1.2731725187519727e-5 4.589359752086932e-5 4.614421424291558e-5; 0.00019276555856013688 -8.068370782747763e-5 4.276951319443858e-5 -0.00011042737144488779 -6.7340537756473375e-6 8.257324686529216e-5 -6.072129592975637e-5 0.00012129445992432625 -4.9118235372946495e-5 -5.8978256609767144e-5 
-6.647868843028219e-5 6.327677593677279e-7 -5.2841655750302574e-5 8.49030157693111e-5 7.28775094754429e-5 3.4884233824476635e-6 -9.222448254913111e-5 -0.00012040983977239544 2.0756774639817612e-5 -3.0037001968952213e-5 -6.746777483239826e-5 6.506411325208903e-5 -0.00010362334699347746 -0.0001370192866746904 0.00016208296171644785 0.00013595256775407165 -0.00011376773448172874 9.020096792629089e-5 0.00012746856834529037 7.169671614211154e-5 -6.199766624230854e-5 1.8654457627787246e-5; 3.0185877061300014e-5 -8.257041479417902e-5 1.630341905504141e-5 0.00018114725828446745 -1.31835824290531e-5 -9.103074551284159e-5 -2.0344251585339754e-5 -0.00023838454506084136 2.2423417998209256e-5 0.00012497468271569485 0.00013387527435362488 4.847424936407506e-5 5.737618705417572e-5 -5.230525419072952e-5 6.600563649364596e-5 -0.00018184339858315737 -5.8644763318154354e-5 -7.334324179872471e-5 9.3702308945633e-6 -0.00015086391356497692 -9.431442156390463e-5 -3.5022013595594734e-5 1.6102697586781035e-6 0.0001435784768759923 4.5269473797373855e-5 -6.522138135231658e-5 -8.797082698628082e-5 -0.00014924153508808307 -0.00013493314616106177 0.0002091896429406122 0.00015619795596968778 -2.533406148039545e-5; 1.2430814823996009e-6 -5.8055906566401625e-5 1.0309197631781617e-5 -8.571347673826708e-5 -9.891036855227457e-5 -7.807078175650127e-5 5.1162350854771135e-5 0.00013933173219735654 8.684484388974343e-5 -0.00014101529480359707 3.597819128950902e-5 -1.5625341194057206e-5 -0.00027564526347380433 -3.5024226536583195e-5 0.0001715611976298335 5.614637090653957e-5 -0.00012805255058158336 -0.00013819750559453505 3.478185287274348e-5 -5.500988511888144e-5 -2.1967633375263436e-5 1.5276716721171247e-5 0.00015468394455564003 -1.4613186368313867e-5 -6.49442993843277e-5 -3.4326629548405204e-5 5.889843275250858e-5 2.451052538996422e-5 4.187459096025767e-5 -6.74545775207985e-5 8.099666926606886e-5 -0.0002646098476642099; -0.00012759851882210066 -6.176191448676587e-5 -0.00011273812481874516 
-0.0002963359212407346 -3.0293141327916536e-5 3.1947610175933744e-5 -2.0876185607325316e-6 -0.00016667695957212774 -4.2438022439369965e-5 0.00014505207297418318 -0.0001420673862702698 4.714672183408573e-5 -0.00011947235100878222 -3.589543042099044e-5 -8.673352488757066e-5 0.0001638828042524382 6.797733009817518e-5 0.00014028180964315013 -7.638686958569094e-6 -5.180420550368811e-6 -0.0002180971561136939 -2.838070680072666e-5 -4.0236881552003134e-5 -8.091912525913197e-5 7.670438941172884e-7 -0.00015540912057249263 -4.980346267896895e-6 2.510861201525965e-5 -3.501913591380944e-5 -0.00011382529112952941 0.00013188685340365076 0.00016228964602132273; -8.755074067571924e-5 5.5647493415053166e-6 -6.364197305934028e-5 -0.0001072207827686319 -0.00010998158667767351 7.969677117820897e-5 4.443539965600081e-5 -0.00020811374396849345 -9.552289829421722e-5 -0.0001234268214769614 -1.639652818390102e-5 0.00012153080625161642 2.4608373320594038e-5 4.835001400956053e-5 0.00018568413234483798 2.5120393371838812e-5 4.3319034927341106e-5 -5.599139093952077e-5 3.242782728615537e-7 7.138268720759216e-5 -4.3584271582724904e-5 -9.787642319984599e-5 -0.00012688763071614777 -8.171607022936982e-5 3.588342461320671e-5 9.467158689098598e-5 0.0001499028354434306 9.167627886803891e-5 -6.360301030631988e-5 0.0001005452710220597 7.2269232138252e-6 1.3445643714152656e-5; -5.395627927781932e-5 8.095125626142483e-5 0.00011621917614548049 6.970443022738437e-5 -0.0001352702970120742 -4.480066725126281e-6 -4.631533052587981e-5 0.00013344688299413157 -0.0001128375083294383 9.212818217065707e-5 5.306041888844568e-5 9.320104747994579e-6 -0.000123932274125302 4.784113438519001e-5 -1.2683326037853855e-5 0.00020009456742856068 -0.00015104707604904704 2.746745625928185e-5 4.139634215108047e-5 7.65472355328801e-5 -8.911259688601983e-5 -9.813307447756424e-5 3.321971557170681e-5 0.00015736471645368905 0.00018051124019849405 -2.7750246279075045e-5 7.835474347531151e-5 -4.9584974892685505e-5 -0.00011651051359648569 
-4.717428278808138e-5 -0.00011173655407900773 3.822370085118652e-5; -1.5021672432960509e-5 -3.5323102118744105e-5 -0.00010319478452282968 0.0001809391019299436 8.16486703946082e-5 -0.0001637856999047274 -5.9145405892818505e-5 0.0001647995436463328 4.9232252869225524e-5 -4.7986680791554984e-5 1.1803257158427928e-5 -8.611686053762204e-5 -0.00012030381322682224 -0.00013007986991276916 0.00018071618114055463 -1.7264881088229404e-5 -2.0567902162694882e-6 -8.637732670359926e-6 5.4481386988586855e-5 -5.760283739498867e-5 4.5774361289035404e-5 -0.00010514254382422242 -2.2330561031312568e-5 6.609705314195892e-5 6.338377578806181e-5 6.413950767757248e-5 -9.79945703044715e-5 0.0002697438817370509 -8.859841319343083e-5 0.0001309472340756854 0.00016414149148778843 -1.8273050508176952e-5; -0.00011095647059065096 -7.31086562076633e-5 -7.603948469423959e-5 7.81933835320114e-5 0.0002528298922776195 9.941000288196086e-6 -4.46356952253006e-6 -0.0001289280931736394 0.00018111408135937037 4.1298141157847973e-5 -5.3252276840252136e-5 8.52240774165263e-6 2.2520061143885575e-6 -8.019608648410617e-5 0.000144401898440802 -0.00011091396444626647 -2.9316342617137513e-5 7.742617898287867e-6 -2.9622973298872326e-5 2.0569216220423943e-5 8.565044700716594e-5 -0.00014406419503885001 2.7459811834513435e-5 0.00014535609662615542 -0.00035004600212360345 4.959499566084042e-5 6.684296968941959e-5 -0.00018068490461828682 0.00023588425975376156 0.00013360089943361908 0.00010278561095091769 1.6598344162590228e-5; 4.8942495965333134e-5 -0.00010558547418781407 -4.349637728748443e-5 -6.140097011097217e-5 4.505040433855888e-5 -1.0909043879786813e-5 -0.0001111807510767273 6.336175957245054e-5 0.00012667033243474703 3.7345563726182522e-6 -1.3194348582549364e-6 5.469620319332394e-5 0.00019611976986767842 6.496217193330958e-5 -8.578153264691389e-5 2.0735441836455474e-5 -3.0051255265545895e-5 4.269593887438366e-7 -3.3937798974631305e-5 -0.00013561235445535226 4.7011544126026796e-5 2.0140408565802935e-5 
-0.00019741611376386712 -3.928556233743172e-5 -0.00011197174225684872 -0.00011016389234440675 -0.00011341663102370371 -3.8852912069836824e-5 -0.00019757670870034521 5.569859368393355e-5 2.7644156404004926e-5 -5.343870051887577e-5; -0.00017562036238333238 -0.00013483525846341883 9.610672849275579e-5 2.1921006450875084e-5 1.2460904586340545e-5 -0.00015480019511843359 9.201858618807178e-5 6.240755600200733e-5 0.00011080360990062098 0.00014819380292946873 -4.990981337606466e-6 -1.2824131331135926e-5 7.108220005219826e-5 7.439559839011897e-5 0.00016061376071347118 -7.176343070997184e-5 1.6462823155318257e-5 -6.255777272118335e-5 -0.00011654346464157 -3.157785842150592e-5 -0.0001348716382514854 -5.190775347246933e-5 -8.762537483080827e-5 -7.814710001826168e-6 -2.1744692864656363e-6 -1.582571032004995e-5 5.0755233574172276e-5 1.5317270922686228e-5 -2.382902540804057e-5 -2.351536069428339e-5 7.579429937804033e-5 -0.00012039385047929591; 8.904566821942905e-5 -5.5626104212616366e-5 -9.62715315742199e-5 0.00021072437412871978 1.6358791627088746e-5 -9.354694010621757e-5 -8.977684446077168e-5 -5.723140601699155e-5 -7.031475788290868e-6 0.00011867987960644402 -7.987442621893135e-5 7.373414301614306e-5 -0.0001248298689366989 -4.06832954373857e-5 0.0002216963435879219 7.510061878779086e-5 -0.00013336132214104076 0.0001246500138316575 9.725695746087188e-5 0.00011228981348753898 -0.0001907545996706689 8.628526448805809e-5 7.75989206980203e-5 0.00020596412251538116 2.9670323429261918e-5 -0.00021442289688789346 -3.965309623677226e-5 -0.00015567437038921644 0.00018913877612280677 -8.219501286444144e-5 9.485621159032844e-5 -0.00012301850473673003; -4.5001595223441036e-5 3.5995188076129407e-6 -5.8186583570808174e-5 0.0002406675473897518 7.058213557990594e-5 -2.6646857501248165e-6 0.00018198721116582413 -7.400671276321414e-5 6.85512967117469e-6 0.00020965859474321113 0.00013594068767212674 0.00027906523499347254 -0.00010784758452850376 0.000194274280859978 9.05806435033325e-5 
9.320579445863481e-5 -1.8332160122233862e-5 -1.9985150371794473e-5 0.00010549206379742016 -0.0001308285002020391 -2.8834291894465798e-5 -3.670233601645728e-5 -9.656596486516981e-5 -0.00014253991068167846 -2.599747969419121e-5 -0.00010295455396466325 -0.0001584542926290248 0.00010200406430464556 1.9633166433199697e-5 0.00011852244532144026 3.839421950338908e-5 0.00014586900832306806; 0.00016751909687358127 -6.405410850944104e-5 7.769441698385267e-5 8.978325137869073e-6 0.00014901736984670043 -0.00025389866378914676 6.730466237698401e-5 -5.4951787308670525e-5 0.0002648982143792127 -1.2882744261311616e-5 7.056975563231752e-5 -0.00010103765554318532 -0.00013084998237260597 0.0001416238909651313 -0.0001322979706974095 8.86699953180541e-5 0.0001113096613035244 7.365124009724092e-5 3.841590558934744e-5 -2.711253465618906e-5 9.967087943280628e-5 8.314101695490575e-5 -2.196464366804907e-5 6.571579696085497e-5 -7.97867970773738e-6 9.778541688050516e-5 -0.0001303637155733283 -8.970998634409279e-5 6.005386866221778e-5 -5.979137596382872e-5 3.230721530329767e-5 2.394657247991114e-5; 0.000191510786015093 -0.0001666720926840561 0.00020109248004528726 -3.7000539970796285e-5 0.00020823207166187883 -2.9180471694306883e-5 8.029002498743709e-5 2.5293315109702786e-5 0.0003581405035839646 3.2221378267519604e-5 -2.8554410102405023e-5 -6.286215404434674e-6 -0.0001238552794637917 -4.7430476053596064e-5 -6.448753259236475e-5 5.714266814549881e-5 -5.550890757045869e-6 3.1047795408191856e-5 -4.837218141973321e-5 -0.00019075667654022376 1.0959547642516177e-5 4.5674583878388784e-5 -1.9161090614846344e-5 -4.8020887272187706e-5 7.439856569191764e-5 -0.0001356726424381871 -1.2073914996908537e-5 0.00010116215714162634 8.844094560856248e-5 0.00014342740184303794 -6.68394349594406e-5 -0.00017913633194610105; 2.5300624557715958e-5 -9.045165014690633e-6 -2.0060392770373863e-5 -5.844001235236732e-5 -0.00024210323054722258 8.380474436416538e-5 -9.143107773415088e-5 0.00010927576194371429 
-0.00016955614147155307 -9.001756289651887e-5 -9.686638541898449e-5 -0.00012675179924263107 -0.00011367733996739323 0.0001316909137050944 -5.819505996331617e-5 -0.00013922928074753605 0.0001433942187679421 5.7861407996098746e-5 0.0001534708816427236 -5.7643731551070816e-5 4.635536505812671e-5 -4.119924152182714e-5 1.9567275459139464e-5 -4.921596074271663e-5 7.628985157789154e-5 -7.736962025664881e-5 3.284591993287627e-5 -0.00024379231500352478 3.306903353513767e-5 -6.962853250006896e-5 -5.338739000579526e-6 4.783936573512261e-5; 7.429064958476523e-5 -1.602472634342338e-5 -3.9338977282123515e-5 8.59200650072988e-6 0.00013474701795846854 -4.943595826855232e-5 -0.00015114514730750267 0.00010197530450884922 -1.0415615130589136e-5 -4.2325739692851884e-5 -0.00013986294556754948 1.1608031903487223e-5 -8.537463777112629e-5 -6.029138507667216e-5 -6.191591723701497e-5 -0.00019725412762138368 5.635781738859389e-6 0.00014540171200995317 -0.00014760094104214293 -2.943938756881869e-5 -0.00010341867068429088 0.00015686396459703208 -7.399314103760023e-5 3.353968082000142e-5 3.1224058634311547e-6 -3.532466775734029e-5 7.412468499158286e-5 0.00014947274126232303 4.2569424343563846e-5 -7.920880211397376e-5 -3.189017929482302e-5 5.8825226963998024e-5; 4.5896336938923637e-5 -0.00021050196288645943 -0.00013618699073351058 -5.43123422751991e-5 -4.5905107857228426e-5 -1.5330329791444547e-5 4.13506751435086e-5 -9.043483793296535e-5 8.905270202982004e-6 -8.762287676592441e-5 -3.49751723054774e-6 0.0001362849471994812 6.372762812741078e-5 -2.1610188799361614e-5 -7.508799523140859e-5 -0.00018109599101383915 5.1840725188116695e-5 5.9963807995400576e-5 0.00011077610191630364 1.5868803654183275e-5 -0.00014709893354240472 0.00010014910015949565 -0.00010941776041665214 -6.734499389796852e-5 0.00011384715265800621 -1.806094939031944e-5 0.0001441539826707114 -2.9314509113956026e-6 -1.8220611185217327e-5 -3.4314816606032845e-5 -1.565906119442819e-5 -7.305416862739858e-5; 0.0001992316853538137 
0.00016097045105919152 4.970327305993635e-5 0.00015767646497618276 3.31995458217794e-5 5.295192265105341e-5 4.189934285743843e-5 0.00011979591605989161 -2.7155530308001668e-5 -1.5478316829170887e-5 -1.9039770752052766e-5 7.071803944920997e-5 -4.92447810139062e-5 1.6624881634860087e-5 -3.283062445195668e-5 -5.475550399939621e-5 -3.4262834862668484e-5 -5.763976271925658e-5 5.564331392445362e-5 6.909053142219644e-5 -0.00013154325764123617 -6.345545745782071e-6 1.7949058626914572e-5 -9.331948725236894e-5 6.315071611236667e-5 -0.00012460312912669186 -0.00020954392433129363 -0.00014148429666608665 6.229616854332734e-6 -0.00011729297451172116 5.833951096786024e-5 3.127835104162701e-5; -0.0001067581271275153 4.5798291342818744e-5 5.068303459094813e-5 3.1906069001438585e-5 -9.131393938905401e-5 7.592289135909795e-5 5.6682978208418395e-6 -1.6292338965556276e-5 9.918349484332986e-5 -1.330345834165153e-5 -8.495689156579636e-5 -0.00024187796419600607 0.00011190731116922508 -9.600680108296398e-5 6.032240877513217e-5 -0.00016555655941947212 1.886973028714381e-5 8.129550939311878e-5 -6.662439442414045e-6 2.119932274918514e-5 0.0001878622259607385 7.336270742935565e-5 4.7508854426014587e-5 -7.891405597209315e-5 -9.58889014657838e-5 -0.00010285565998519217 -3.197993840903693e-5 -6.465603122369222e-5 -0.00017598641182875818 -0.00011217572513150825 4.156026794926855e-5 -2.6664331140907122e-5; 3.563142547570815e-5 -0.000162063837137987 -0.00014808731918969747 0.00020170150548872847 0.00014650494142623503 -4.893399943758841e-5 7.822639474831392e-5 0.00010327435904836053 5.433482021979889e-5 -0.00014183233849985116 -5.5277153819781754e-5 0.00012894252597542 6.357661061408383e-5 8.523370217952515e-5 -8.651731457142614e-5 5.246946477635665e-5 0.00015851591763272997 2.6030883077308137e-5 -3.2062341574812454e-5 0.00015432832753938894 9.894566630897677e-5 -1.0102531064347799e-5 0.00011934083033770772 0.00011092083752806311 2.813000048698507e-5 0.0001510340067623282 -0.0002321654186518712 
-0.00013089506080651777 6.543316578680118e-5 8.662057246036722e-5 0.00018007747587469396 -2.5594653721067627e-5; -5.016496159172566e-5 -3.7097752808258365e-5 0.00011382329960084752 -0.00017891837940354826 -9.902358449566219e-5 4.632539405334698e-5 1.7866411711076504e-5 -0.00010904563403259064 5.5917310803997834e-5 -0.00011505176242011372 -0.00011862881963410457 -7.75146314272565e-5 5.047568031201212e-5 -0.0001460901317632666 3.5464808415437936e-6 0.00010498883186570327 -7.223343580066132e-5 0.00017503709516218896 7.121390938024142e-5 -7.075373157662336e-5 -1.957367801609313e-5 -1.2963581481962428e-5 -7.476196203882582e-5 0.00016279673711593447 -8.467210007169306e-5 0.00012088988425872312 1.037290751170113e-5 0.00013308377803965434 0.00019605404797672289 -9.006333654760802e-6 -6.315127264046082e-5 -1.550900795067691e-5; 8.501828688424594e-5 0.00011784466460341877 1.0869086511619495e-5 -0.0002839839917189029 -9.614468308170964e-5 -2.4862379175081734e-5 0.0001343762078275412 -4.855525275814033e-5 3.4478292039076886e-5 -2.215262702263422e-5 -1.9359647010040375e-5 7.790798099706318e-5 4.6162592300629315e-5 -7.051681040791451e-5 -0.00018663422542629032 9.081239386689207e-5 -8.394749329197532e-5 -4.453064411534462e-5 4.299413467636047e-5 -8.378965594345302e-5 -5.3614193345468764e-5 -2.8443399966979125e-5 5.797962058098053e-5 -8.051158965570023e-5 -2.8666595423763015e-5 6.662983924145326e-5 0.0002185989096713659 -9.809525090275039e-5 0.0001997661556768641 -8.5796212258336e-5 2.1914275751281473e-5 -7.881910732719728e-5; -0.00014983818876117806 0.00011836725859410361 9.792819871308634e-5 0.00013217183111208892 8.850501141949794e-5 -0.00010567546162548241 -0.0001208859165015453 -2.1364582451327783e-5 4.874743278577222e-5 -5.929213405379685e-5 -0.00011702094236881242 -0.00015289710320558973 -2.926681629815859e-5 -6.632167681147032e-5 -0.0001346376801191641 0.0001116275463320431 -2.2488246784458668e-5 6.540151041207416e-5 1.0021979662444128e-5 0.00010989642866860556 
5.398206955960009e-5 5.229190462358705e-5 4.024738625248293e-5 -0.00028875171740565286 -7.290423566502209e-5 0.00010096032308155222 -1.2769109914928341e-5 -0.00012009415406993352 -2.63161226283163e-5 0.00010800351553569782 0.00010207505071974863 -8.753671029609355e-5; -7.707424269418562e-6 8.30259790705586e-5 -2.1276161173244338e-5 -7.327473813750796e-5 0.0002607312320793117 0.00011858871348119102 -0.0001866909757194889 7.511319326619429e-5 -0.000263364755123288 0.00016362264922971137 0.0001606088457224959 -0.00021387584511506444 4.113543527234706e-5 -0.00010914731920314152 -5.0924054998561914e-5 -0.00012311058763264922 3.2009987440300314e-5 -8.308260531014625e-5 -2.1900825781290655e-5 -1.8834753613583724e-5 0.00012918429037706686 6.357933798023495e-5 0.0001742755388437047 -0.00018741342010291292 -0.00018348779358751268 -8.082597433206528e-5 2.1805273168472186e-5 -0.0001095545108950627 5.954328425611315e-5 -0.00010478630666005662 -3.1970378226496876e-5 1.3320784451540597e-5; -3.8141802516637446e-5 5.144882800652808e-5 3.9186722122952906e-5 -8.866535733557312e-5 -5.5953848797230335e-5 3.958442232817572e-5 0.00010256520966196923 0.00010022196568445424 0.00010684485517106182 3.8937029447528275e-5 0.0001738408105926512 -0.00012402465597131246 -9.169092606623934e-5 -0.00011046441738597041 -5.7443892157046006e-5 -0.00015208060301245024 0.00015202280394862038 -0.00031236941083205046 5.743395434239473e-5 -5.333865862792375e-5 8.013205718145187e-5 0.0002549857532618473 -4.426774773618129e-5 0.00013838309007825565 4.826990391722963e-5 -5.0965826968765755e-5 -6.0993187954750956e-5 -0.00020200576488730343 0.00023476249100769895 6.279943102401842e-5 -1.088966052646158e-5 0.00011915519337765061; 9.511119461692712e-5 1.5904876284346567e-5 0.00021669668823322063 3.702760333269691e-5 7.70772170771832e-5 -1.1001513109497149e-5 6.331777888569575e-6 6.736909600747619e-5 -2.7317296905514202e-5 5.541306939427906e-6 -3.389790398343596e-5 -7.239491204656799e-5 -9.288403779193924e-5 
3.7709433320722106e-5 -0.00015635843700422168 1.9064781020440744e-5 -1.6506264352776913e-5 3.498959305304881e-5 -0.00020051575810341345 -3.531602450222853e-5 -7.096167208785377e-5 -0.00013401087161312025 -5.0701964562015906e-5 0.00017651848648981918 -3.828137740767698e-5 5.944750629439799e-5 5.976144567555611e-5 -9.050242041396106e-5 5.571909824809849e-5 1.405321692161832e-5 -4.781956328790754e-6 8.668036790639032e-5; 9.809687529400833e-6 -8.389810636488083e-5 8.335387292702268e-5 -0.00015235808632185064 1.684694371163086e-5 -4.88429645852867e-5 -3.9087538877151825e-5 4.424741974019149e-5 -4.4234809589919765e-5 2.772563539422169e-5 6.831700057175649e-5 -2.57078568179843e-5 -0.00011261576537124226 -4.6403612483640434e-5 -4.488025251391719e-5 0.00013654293480015064 5.735254353878664e-5 -0.000124543657417757 9.671571070014141e-5 -9.471103794255607e-5 -0.0001297753910844904 9.25439967463041e-5 5.406969330890623e-5 3.25808291594738e-5 -6.585044274886441e-5 -5.3592957098384824e-5 1.7592459502573475e-6 6.329995858634529e-5 3.395057911204796e-5 -7.40459140584554e-5 0.00025974634673132235 0.0001023282316086132; -0.00010452314491862485 -0.00013656971460453973 -7.8475951531572e-5 -0.00012554240421172155 -2.9133124562341653e-5 4.410761065651856e-5 -2.8390620363776067e-5 -9.036445881951431e-5 -3.0922417144885196e-5 2.697176833181283e-6 0.0002879711053449308 0.0002405787188674893 9.248896348276993e-5 -0.00011037425902917842 0.00013936659778687798 -0.00013945000428952083 3.2220507628385094e-5 7.868909850936384e-5 1.9057021554776295e-5 0.00018771884314470062 7.608337436090875e-5 -0.00016925043492823768 -4.8016661103019085e-5 0.00024768846450592914 -0.00023271440769325808 9.024382851772034e-5 -0.00013282575419884041 6.976982948823091e-5 -1.3217723980908224e-5 3.125201582430124e-5 8.877198960386223e-5 -1.9705997337715915e-5; -0.0001241948896773613 2.528545509796662e-6 -0.00010135497478642122 3.343749875910627e-5 0.00011334359686714583 -1.6701006925921108e-5 -4.061520422566467e-5 
9.448860754617489e-5 -1.4399212259939103e-5 -7.597727572477634e-6 -0.0001911910051700195 1.7265820760606162e-5 0.00010161142080449651 5.1229346981365386e-5 0.00017874579015313906 -4.8871087811328636e-5 -0.00010169682837896797 5.68755592115698e-5 -6.133369252543524e-6 2.2244776745886133e-5 4.875921052963081e-7 4.2869784637680776e-5 -8.692995953480341e-5 -0.00013529295148201906 4.35058343003971e-5 -0.00013215993868526924 0.00013098557491446862 5.255713648637743e-5 -5.7961974844708676e-5 5.724757892438195e-5 -9.272022481183433e-5 2.8916884447773926e-5], bias = [4.723403284684363e-10, 3.460852399292578e-9, 2.714036485105339e-9, 3.1020281939330895e-10, -1.2106979623247932e-9, -2.41955679996407e-9, -4.066470779534918e-9, -5.32412155797365e-10, 1.0036660672435531e-9, 1.4319915151801225e-9, 9.298752012821727e-10, -2.7654039006717447e-9, -8.644880939065156e-10, 9.418821253166126e-10, 3.855151015941593e-9, 2.4837272976319246e-9, 1.746291783268293e-9, -2.9877890529382595e-9, -1.3887497720154599e-9, -2.0189539410751964e-9, 3.5167243239324926e-10, -2.2283859815509792e-9, 4.3061161590520885e-9, -3.552277483588167e-10, -8.631079056308144e-10, -1.3813658501353767e-9, -1.8650138681995682e-9, 1.3611758051046515e-9, 1.1679092293677596e-10, 1.6052893069909687e-10, 9.556881730530574e-10, -5.21422969340891e-10]), layer_4 = (weight = [-0.0007244798230098076 -0.0005948697543728443 -0.0007189069543768122 -0.0008233763442447945 -0.0007577607896755383 -0.0006497962913254424 -0.0005198685453515985 -0.0007904170450485005 -0.0007087756862908044 -0.000494571400282072 -0.0007834618132387952 -0.0007050829774530013 -0.0007853117422120672 -0.0006993132261101605 -0.0007361899787943619 -0.000760168721013117 -0.000613874568895234 -0.0009927577303257492 -0.0006607004053762422 -0.0007425696819166549 -0.0007343859469451837 -0.000650665751953791 -0.0005152671162457509 -0.0008100082994271134 -0.0005285693277735826 -0.0006242341602720748 -0.0007152840031366606 -0.0007349028945250861 -0.0007345552682097741 
-0.0007727992800480504 -0.0005445816546340689 -0.0005831662680903435; 0.00029118902908376726 0.00023098738770980568 0.0002420208332080712 0.0002566683384140215 0.0001360434902128218 0.0003357285072700591 0.00016796289994481405 0.00031078341880577334 0.00023209680133548042 0.00013418203472274676 0.0001955634479924839 0.0002968241072694962 0.0003759692754064416 0.000224873464532289 0.00026166564975656784 0.0001836418755221586 0.00032995382274483 0.000221734013967972 0.000361185894641462 0.00018420651585440516 0.0002402641042255142 0.00016808406424769404 0.00017260565775094853 0.0001992300570082269 0.00024240556447841866 0.00029526656320476486 0.00022292589712126075 0.00020836753864889027 9.880982307153661e-5 1.4844747875164454e-6 0.00040850697862113085 0.0001927366056785799], bias = [-0.0006805575785330337, 0.00027137663486038446]))
Visualizing the Results
Let us now plot the loss over time
# Plot the recorded training loss against iteration number, as a line plus markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, eachindex(losses), losses; linewidth=4, alpha=0.75)
    scatter!(ax, eachindex(losses), losses; marker=:circle, markersize=12, strokewidth=2)
    fig
end
Finally, let us visualize the results by re-simulating the model with the trained parameters
# Re-simulate the orbit with the trained parameters `res.u` (same fixed-step RK4
# settings as before) and convert it to a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
# Overlay the reference waveform, the untrained prediction and the trained prediction.
# Each series is a line plus circular markers, combined into one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    line_style = (; linewidth=2, alpha=0.75)
    marker_style = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l1 = lines!(ax, tsteps, waveform; line_style...)
    s1 = scatter!(ax, tsteps, waveform; marker_style...)
    l2 = lines!(ax, tsteps, waveform_nn; line_style...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_style...)
    l3 = lines!(ax, tsteps, waveform_nn_trained; line_style...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_style...)

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
Appendix
# Report the Julia/system configuration. Also report GPU toolchain details when
# MLDataDevices is loaded and reports a functional CUDA or AMDGPU device.
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices) &&
   @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) &&
   @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.