Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$.
"""
    one2two(path, m₁, m₂)

Split a one-body (relative) trajectory `path` into the trajectories of two bodies with
masses `m₁` and `m₂`.

With `M = m₁ + m₂`, the first body follows `(m₂ / M) .* path` and the second follows
`-(m₁ / M) .* path`, so that `r₁ .- r₂` reproduces `path` (up to rounding).

# Returns
The tuple `(r₁, r₂)`; each entry has the same shape as `path`.
"""
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = (m₂ / M) .* path
    r₂ = (-m₁ / M) .* path
    return r₁, r₂
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix into Cartesian orbit coordinates.

`soln` stores one state variable per row and one sample per column:

- 2 rows `(χ, ϕ)`: `p` and `e` are read from `model_params = (p, M, e)`, which must
  then have length 3 (`M` is not used here).
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` are taken per sample from rows 3 and 4 and
  `model_params` is ignored.

# Returns
A `2 × N` matrix whose first row is `x = r cos(ϕ)` and second row is `y = r sin(ϕ)`,
with `r = p / (1 + e cos(χ))`.

# Throws
`AssertionError` if `soln` does not have 2 or 4 rows, or if it has 2 rows and
`model_params` is `nothing` or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Guard against the default `nothing` first so the failure is the documented
        # assertion rather than a `MethodError` from `length(nothing)`.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled signal `v` (sample spacing
`dt`) by finite differences.

Interior points use the central difference `(v[i + 1] - v[i - 1]) / (2dt)`; the two end
points use three-point one-sided stencils. `v` must hold at least 3 samples, because
`v[3]` and `v[end - 2]` are indexed.

# Returns
A vector with the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    # Forward one-sided stencil at the first sample.
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Central differences for every interior sample.
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    # Backward one-sided stencil at the last sample.
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled signal `v` (sample spacing
`dt`) by finite differences.

Interior points use the central stencil `v[i - 1] - 2v[i] + v[i + 1]`; the two end points
use four-point one-sided stencils. `v` must hold at least 4 samples, because `v[4]` and
`v[end - 3]` are indexed.

# Returns
A vector with the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    # Forward one-sided stencil at the first sample.
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # Central second differences for every interior sample.
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    # Backward one-sided stencil at the last sample.
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Compute one component of the mass-weighted moment tensor along an orbit.

`orbit` holds the `x` samples in row 1 and the `y` samples in row 2. `component` is an
index pair: `(1, 1)` selects `x² - (x² + y²) / 3`, `(2, 2)` selects `y² - (x² + y²) / 3`,
and any other pair selects the off-diagonal term `x * y`.

# Returns
A vector with one entry per orbit sample, scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    xx = xs .^ 2
    yy = ys .^ 2
    xy = xs .* ys
    tr = xx .+ yy

    picks_xx = component[1] == 1 && component[2] == 1
    picks_yy = !picks_xx && component[1] == 2 && component[2] == 2

    entry = if picks_xx
        xx .- tr ./ 3
    elseif picks_yy
        yy .- tr ./ 3
    else
        xy
    end
    return mass .* entry
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the moment-tensor entry `component` of
`orbit`, where `dt` is the spacing between orbit samples and `mass` scales the tensor.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the three independent quadrupole components for `orbit`.

# Returns
The tuple `(h11, h12, h22)`, each produced by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    entry(i, j) = h_22_quadrupole_components(dt, orbit, (i, j), mass)
    return entry(1, 1), entry(1, 2), entry(2, 2)
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the two strain polarizations generated by a single body moving along `orbit`.

Constants are built in the type `T` of `dt`.

# Returns
The tuple `(s * (h11 - h22), -s * 2h12)` with `s = sqrt(π / 5)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    hxx, hxy, hyy = h_22_quadrupole(dt, orbit)
    prefactor = sqrt(T(π) / 5)
    plus = hxx - hyy
    cross = T(2) * hxy
    return prefactor * plus, -prefactor * cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components of two bodies, each evaluated with its own orbit and mass.

# Returns
The tuple `(h11, h12, h22)` of component-wise sums.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    h11 = first_body[1] + second_body[1]
    h12 = first_body[2] + second_body[2]
    h22 = first_body[3] + second_body[3]
    return h11, h12, h22
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2) mode strain from the orbits of BH1 (mass `mass1`) and BH2 (mass
`mass2`).

The masses must sum to one to within `1.0e-12`; otherwise an `AssertionError` is thrown.

# Returns
The tuple `(s * (h11 - h22), -s * 2h12)` with `s = sqrt(π / 5)`, where the `h` components
are the sums over both bodies.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    hxx, hxy, hyy = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    prefactor = sqrt(T(π) / 5)
    plus = hxx - hyy
    cross = T(2) * hxy
    return prefactor * plus, -prefactor * cross
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into the pair of strain polarizations.

`soln` is first converted to a Cartesian orbit with `soln2orbit(soln, model_params)`;
`dt` is the spacing between the saved samples.

- `mass_ratio == 0`: the orbit is treated as a single body (`h_22_strain_one_body`).
- `0 < mass_ratio ≤ 1`: the orbit is split into two bodies with masses
  `m₁ = mass_ratio / (1 + mass_ratio)` and `m₂ = 1 / (1 + mass_ratio)` and passed to
  `h_22_strain_two_body`.

# Returns
The 2-tuple returned by the selected strain function.

# Throws
`AssertionError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)

    # A zero mass ratio needs no two-body split: the orbit itself is the only trajectory.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)

    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
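`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses
$$u[1] = \chi, \qquad u[2] = \phi,$$
$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$
where $p$, $M$, and $e$ are constants.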
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place right-hand side of the relativistic orbit ODE.

`u` holds the state `(χ, ϕ)`; the second argument is destructured into the constants
`p`, `M` and `e`; `t` is not used.

# Returns
The two-element vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    ecosχ = e * cos(χ)

    # Factors shared by both derivatives.
    shared_numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    shared_denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared_numer * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * shared_denom)
    dϕ = shared_numer / (M * (p^(3 / 2)) * shared_denom)
    return [dχ, dϕ]
end
# Mass ratio handed to `compute_waveform`; 0 selects its one-body branch.
mass_ratio = 0.0 # test particle
# Initial state (χ, ϕ) of the orbit ODE.
u0 = Float64[π, 0.0] # initial conditions
# Number of saved time samples.
datasize = 250
tspan = (0.0f0, 6.0f4) # time span for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
# Spacing between saved samples; passed to `compute_waveform` as the finite-difference step.
dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step handed to `solve`, which is called with `adaptive=false`.
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a 2-tuple; only its first element is kept as the reference data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the simulated waveform as a line with the saved samples overlaid as markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    # Group the line and the markers under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (the state variable χ) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl:
# Network with one input and two outputs, used as a learned correction inside `ODE_model`.
# Weights are drawn from a truncated normal with std 1.0e-4 and biases start at zero, so
# the initial outputs are tiny — presumably so that `1 .+ nn(...)` starts close to 1.
const nn = Chain(
# Input layer: applies `fast_activation(cos, x)` to the raw input.
Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
# Output layer: no activation function is given.
Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
ps, st = Lux.setup(Random.default_rng(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-3.5609665f-7; 1.4296613f-6; 5.735454f-5; -4.2540625f-5; -0.00014098312; -0.00012756168; 2.8947104f-6; 7.979021f-6; -3.2569835f-6; 0.00015863123; 6.488775f-5; 2.0026808f-5; 2.0812891f-5; -3.4645975f-5; 1.7330387f-5; 0.0002610456; -3.8292765f-5; 2.9562698f-5; -9.55768f-5; -7.8016725f-5; 6.171631f-5; -0.0001984115; -0.00017896484; -0.00022969715; 8.1427475f-5; 4.7045607f-5; -0.00013377974; 0.00014182167; 9.19006f-5; -2.362098f-5; 1.2974027f-5; -4.8360183f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-9.4721676f-5 7.5405165f-5 -0.00016020963 7.005652f-5 6.8757254f-5 3.701062f-5 -0.00015945877 -1.679246f-5 -2.517982f-5 0.00022908978 -8.126443f-5 -4.450441f-6 0.00012502873 -0.000118430224 3.0324016f-5 1.0511999f-6 -0.0001041963 0.00010089816 0.00013129211 -0.00019790005 -6.413537f-6 0.00018907122 -0.00016956602 9.031603f-5 0.00011092242 -0.00019756131 3.180817f-5 -0.00011021083 -2.0500445f-5 -0.00020167824 -3.19393f-5 -0.00026134032; 1.8045718f-6 0.00013255607 0.00012124404 -0.00016649609 0.00023776363 4.414413f-6 4.9804465f-5 0.00011099546 9.306089f-5 -0.0001091609 6.1558065f-5 9.69988f-6 2.5101097f-5 0.00026496116 7.619028f-5 9.148411f-5 1.54531f-5 -2.863757f-5 8.384889f-5 -2.5391297f-5 0.00014740512 -3.1523054f-5 -6.594316f-5 8.701418f-5 1.5303563f-5 -2.9741632f-5 -0.00021258864 9.666014f-7 -4.99992f-5 0.00018538292 3.2130465f-5 -0.00023650296; -8.2962826f-5 1.2735087f-5 -8.75204f-5 -9.424416f-6 5.8977952f-5 -1.8923125f-5 -0.00014544075 -0.00011883932 5.5563225f-5 0.00017290402 -9.5028874f-5 0.00016190992 5.1859264f-5 -4.7727648f-5 -0.00011378697 -8.802828f-5 -3.8837352f-5 2.0046227f-5 1.2762984f-5 -1.9771518f-5 -0.00012647442 -0.00015816146 -9.719832f-5 0.0001536185 7.0264665f-5 2.2607303f-6 0.00019388061 
5.463546f-6 -9.182857f-5 0.00025801614 -3.264012f-5 4.775794f-5; 1.793082f-5 -1.3674482f-5 -2.152599f-5 2.7234792f-5 -0.000120371864 0.00011523126 -8.371415f-5 0.000101372665 -0.00011709626 -1.6846088f-5 -7.957964f-5 5.387739f-5 -0.00027420526 5.680357f-5 -8.7209235f-5 2.5287201f-5 -4.752545f-5 -1.9039348f-5 -2.4763105f-5 8.3759085f-5 -2.338006f-5 1.11724785f-5 -4.9289276f-5 -0.00018166796 6.0163005f-5 -2.649804f-5 0.00016111897 -1.8784136f-5 -0.00016997507 0.00020500955 5.917554f-5 -5.8405767f-5; -3.4163975f-5 -1.77275f-5 -5.0760922f-5 1.6724016f-6 -6.0226146f-5 -3.8957412f-5 1.690836f-7 -6.009451f-5 2.1410755f-5 -3.626747f-5 -4.041142f-5 0.0001638136 8.1158694f-5 6.847542f-5 0.00016058164 4.1716754f-5 0.00013086661 5.9777332f-5 -0.00012238734 0.000108197026 0.00010373531 -0.00018038678 2.5168038f-5 6.4299275f-5 -7.521649f-5 3.448114f-5 0.000138311 -6.581654f-5 -8.336239f-5 4.8028767f-5 5.0219594f-5 0.00010138068; 6.377056f-5 -3.9480943f-5 2.6828331f-5 -1.5962216f-5 0.00021796822 -4.2900756f-5 0.00017478608 0.00010248549 -4.0854484f-5 -4.6399196f-5 -9.345008f-5 -4.0358704f-5 0.0001714792 2.0782034f-5 0.00021278166 9.6370444f-5 -3.2497366f-5 -3.5557066f-5 -3.1270385f-5 1.1384622f-5 2.4592258f-5 -8.563204f-5 -5.870181f-5 -3.9583047f-5 0.00011484236 4.1934603f-5 -7.3367715f-5 2.9611716f-5 2.7879836f-5 -0.00010616357 -1.0041319f-5 -2.761316f-5; -1.7770757f-5 4.4246975f-5 2.5150734f-5 -0.00013702817 -0.000101266174 -4.996962f-5 -0.00011802794 0.00021078762 0.00010759972 8.142844f-5 1.49138505f-5 -8.962188f-5 0.0001390288 5.5195105f-5 -0.00032075727 7.810757f-5 0.0001448369 -7.219841f-7 8.713682f-5 -4.9791695f-5 0.00012914673 1.7317996f-6 9.396489f-6 -4.7146357f-5 -0.00017026566 4.2327505f-5 -0.00025710283 -7.645546f-5 0.00017031276 0.0002092774 3.7056845f-5 -0.00019672052; -0.00026314173 -4.5667206f-5 5.2313637f-5 -0.00017406854 0.00031684476 2.7306776f-5 0.00014789308 2.8405404f-5 -1.2066691f-5 3.5333742f-5 5.0135277f-5 -5.332376f-5 -7.242823f-5 -1.136005f-5 
0.00012258039 3.3790486f-5 0.000100736885 -0.00015202102 -8.896023f-6 3.9256363f-5 8.613521f-5 -4.486434f-5 -1.3866278f-6 -3.6461268f-5 -0.00010072548 5.99767f-6 9.417521f-5 -0.00014406304 0.00012318401 -0.00015912607 -0.00010304168 0.00019566917; 0.000106508305 -0.00016981 -1.6457705f-5 -4.185373f-5 -8.285721f-5 -8.975269f-5 -9.223424f-5 -0.00011571371 7.948685f-5 -9.246045f-5 -7.2397146f-5 -6.0720754f-6 -0.00014196358 -7.547941f-5 -3.5701993f-5 1.59997f-5 -7.004262f-5 1.8029532f-5 -4.4611817f-5 -8.9319525f-5 0.00017569646 0.00013813219 -4.181921f-5 -0.00024276206 -2.1852015f-6 -2.9688777f-6 0.00018602023 9.064932f-5 1.0486815f-5 -1.8725668f-5 -0.00011398758 -2.4228519f-5; 4.5900882f-7 -9.2290145f-5 8.83688f-6 -0.00013861842 -2.1899301f-5 0.00015166166 0.00012348325 1.45763215f-5 -5.4116852f-5 -4.5822442f-5 6.9552852f-6 1.6376887f-5 -1.5571306f-5 -3.891935f-6 3.0407924f-5 0.00010875831 -0.00012184081 7.879402f-5 -7.175914f-7 7.001282f-5 -5.8480877f-5 3.689482f-5 0.00016568201 -5.6494137f-5 -0.00011082164 -0.00017194514 0.00016411899 -2.8384016f-5 0.00010892312 0.00013858662 -8.515984f-5 9.671691f-5; -7.756533f-5 1.4505203f-5 -0.000100332116 4.5197423f-5 3.619255f-6 -9.0516274f-5 -5.853295f-5 -0.00012982583 -1.9560228f-5 3.2787033f-5 9.019073f-5 0.00022018807 4.864092f-5 -3.155245f-6 -9.106583f-5 0.00023125451 -4.144492f-5 7.6393015f-5 -6.962529f-5 1.1774564f-6 -2.4589996f-5 5.157022f-5 1.3792019f-6 6.016714f-5 9.322034f-5 -0.00022091561 -0.0002207782 5.2217674f-5 0.00011952616 -2.7391943f-5 -5.430463f-5 9.132611f-5; 2.656112f-5 -1.07781125f-5 2.8550863f-5 -8.734941f-5 5.423906f-5 8.448873f-6 1.1826801f-5 9.747937f-5 -4.6230634f-6 -0.00019159418 9.132126f-5 7.994687f-5 6.564849f-5 0.00012872543 -9.3470335f-5 -1.7187913f-5 -0.00010515723 6.6567954f-5 -0.00013204978 -3.2470427f-6 -0.00020976504 7.712304f-5 -5.5932393f-5 2.3944705f-5 9.721864f-5 -3.730731f-5 -4.1138264f-5 -3.197661f-5 4.2369833f-5 -1.7044475f-5 4.1874573f-5 -0.00014527494; -4.0330115f-6 1.2867798f-5 
0.00013959929 -2.5063107f-5 -0.00016140223 -1.2558549f-5 7.4822674f-5 6.476299f-5 6.583076f-5 -1.51472f-5 -4.71935f-5 2.993625f-5 -0.00014161291 -8.447247f-5 4.4095574f-5 -1.5399353f-7 6.690964f-5 5.302521f-6 -0.00010153969 1.5050651f-5 -2.6463362f-5 4.7069952f-5 0.0002055194 -8.0744016f-5 4.6173558f-5 -2.4565961f-5 -9.8287346f-5 -0.00012839632 -6.607917f-5 -1.5994432f-5 -0.00016291336 -9.9512465f-5; 0.00016269444 8.2220045f-5 6.598376f-5 2.1923846f-5 0.00016464616 9.5544936f-5 -0.00018172804 -0.00021240492 3.2751865f-5 0.0001256819 -0.00013611549 0.00017477217 -3.01007f-5 -8.72018f-5 7.859864f-5 0.0001268572 -9.353054f-5 1.4543239f-5 0.00013108729 0.00019260052 -5.810617f-5 -4.128591f-6 -0.00016568706 0.00014768145 -3.2961234f-5 -0.00024017734 -0.00020638104 -6.3245636f-5 7.915442f-5 2.4796798f-5 -4.8726f-6 -0.00012305831; -0.00015758275 -0.00013437943 4.229262f-5 -7.314669f-5 0.00011123302 1.0980783f-5 7.566389f-5 5.5344084f-5 -0.00010070093 -4.6227487f-5 -6.64117f-5 9.434421f-5 9.920567f-6 4.5812467f-5 0.00012507866 -1.5649191f-5 -6.410805f-5 3.96057f-5 4.4554663f-6 0.00011362418 -2.3506082f-5 -9.669833f-5 -4.171652f-5 0.00011319593 -6.665713f-5 -6.801631f-5 -8.116993f-5 -0.00016853126 -9.266238f-5 6.3813513f-7 -2.0343206f-5 -5.8070887f-5; 9.447887f-5 2.9120347f-5 -0.00014080672 -3.7403565f-6 5.8971505f-5 -1.10936435f-5 8.8796965f-5 0.00012089685 -1.1048531f-5 -3.4556306f-5 -0.00012677656 1.0936487f-5 4.014813f-5 -3.7379767f-5 9.348715f-5 0.00014645074 -6.6371416f-5 3.721095f-5 8.3750136f-5 4.84943f-6 0.00029084552 -0.00015203095 -1.9854953f-5 -1.2665714f-5 0.00010003649 -1.33840485f-5 -5.5852364f-5 3.69664f-5 -6.205709f-5 -4.93924f-5 -8.159049f-5 8.026602f-5; -8.492526f-7 1.3159706f-5 -0.00010351275 0.000115885465 1.6445181f-5 0.00023365066 -0.00016762629 -4.5354347f-5 0.00011810787 0.00012816672 -5.253855f-5 7.672162f-5 1.6679522f-5 0.00021453694 -2.8668768f-5 6.460556f-5 8.3641506f-5 -4.4751926f-7 -0.00021001043 -5.187617f-5 -5.573318f-6 -5.1476716f-5 
0.0001448255 -1.883018f-5 1.0824038f-5 0.00017652035 5.7147714f-5 1.082318f-5 -0.00021353854 -3.6153164f-5 3.6464313f-5 6.3124484f-5; -1.5837613f-5 5.416549f-5 0.000108349224 -1.7027845f-5 3.5834644f-6 0.00013723553 1.2566498f-5 -7.475301f-5 -3.890306f-6 -7.074255f-5 -4.1975814f-5 -5.937939f-5 -0.00012835402 -1.7850267f-5 -2.2589502f-5 -8.43927f-5 -8.962627f-5 -2.6966092f-5 0.000118935975 0.00013848774 3.1758589f-6 -5.3247444f-5 9.645976f-6 -9.3429495f-5 -4.1493444f-5 -0.00023068151 -6.283152f-5 3.2295604f-5 6.0976374f-5 0.00015592492 0.00014617265 4.2517888f-5; 0.00027620967 -3.9248484f-5 -7.3709954f-5 -6.349018f-5 3.0320423f-5 0.0001476996 3.0618925f-5 2.5645557f-5 4.7046988f-6 4.144849f-5 2.7989312f-5 1.0054458f-5 3.5526602f-5 -1.5137679f-5 -2.9675846f-5 -5.258374f-5 5.099835f-5 8.689004f-7 -1.3234625f-5 3.3462024f-5 -2.3912615f-5 0.00024120949 5.7682115f-5 -3.367046f-5 -8.854429f-5 -6.833892f-5 2.6980722f-5 8.709771f-5 -7.643344f-5 -5.7833364f-5 -0.00018403913 8.93749f-5; 0.00010376295 -3.6407917f-5 3.1401327f-5 -6.481887f-5 -1.4436474f-6 1.554844f-5 6.3915344f-5 -0.00037255278 -0.00012627358 -5.2338393f-5 3.5643185f-5 -8.231645f-5 -6.283756f-5 8.2433726f-5 7.13309f-5 -8.571021f-5 0.00028041503 -0.00011740175 -0.000104828556 -4.6562833f-5 -0.00010396245 0.000101344296 -4.8438193f-5 2.3544279f-5 4.854186f-6 -1.930549f-5 -8.307096f-5 0.00018006757 5.2972788f-5 -8.1874205f-6 2.3591423f-5 -0.0001148228; 3.6913298f-5 0.0001229591 0.00010614159 0.0001762486 7.1915165f-5 0.00016197866 -0.00012998277 2.8674229f-5 -4.6859517f-5 -9.668344f-5 -6.4915366f-5 -1.8912599f-5 -0.00020305217 8.415766f-5 0.00013445292 0.00019480867 0.00011383066 -1.4143321f-5 1.8310733f-5 0.00016948764 -7.365366f-6 3.900732f-5 3.0655236f-5 -6.836324f-5 8.039573f-8 -9.766876f-6 -0.00022457604 -7.325278f-5 -3.21232f-5 9.334748f-5 -6.5286696f-7 0.0002200015; -3.7821148f-5 3.47674f-5 8.0442815f-6 -5.1646393f-6 0.00013633838 0.00013826044 -9.969484f-6 0.00011684983 -9.3660834f-5 -3.7903617f-5 
0.0001357111 -0.00013638096 2.199275f-5 -5.5966604f-5 8.421477f-5 -1.2592114f-5 1.9416616f-6 -0.0001354088 0.00014669557 3.3248211f-6 -3.299985f-5 -6.1509754f-6 -2.1261456f-5 -0.00019095321 -8.047801f-5 2.966328f-5 -6.403807f-5 -1.1932758f-5 -9.852104f-5 -0.000106636115 0.00021849544 2.6876349f-5; -2.9763014f-5 -6.1316055f-6 0.00010967284 -9.91625f-5 5.3679356f-5 -8.5052267f-7 0.00013162251 -7.124252f-5 2.5510268f-5 1.7571909f-5 3.765433f-5 3.2075422f-5 0.00011027888 -0.00015442437 0.00010947939 -0.00020866073 -0.00027615207 6.214988f-5 -2.96233f-5 -2.048819f-5 7.6755074f-5 -0.00011727608 1.5953876f-5 9.649266f-5 -3.6845504f-5 -0.00011437054 -0.00011720264 -5.5687666f-5 -6.399861f-5 -0.00013549929 7.949848f-5 -0.0001378104; -2.3987895f-5 -7.289566f-5 -0.00010114774 -0.00011049174 9.235256f-7 -4.3243825f-5 -6.5840635f-5 2.1081436f-5 -7.110702f-5 5.783613f-5 8.113431f-5 8.312855f-6 4.3820364f-5 -5.543572f-5 5.3604843f-5 -7.142941f-5 3.1883286f-5 7.725977f-5 -1.9529893f-5 -6.610236f-5 -0.00015132429 0.0002205234 -5.7386464f-5 3.3433345f-5 -0.00015761584 -0.00011231137 0.00020389783 -9.7893635f-6 2.1324147f-5 6.43104f-5 -3.5982095f-5 0.00020525052; -0.00012704627 -1.27062285f-5 -0.00014224669 -8.995297f-6 -0.00010389015 8.081772f-5 -0.00021548284 -5.0371636f-6 -3.685107f-5 -9.220424f-5 -0.0001405776 -9.03483f-5 9.749299f-5 -0.0001598919 2.3105764f-7 3.974332f-5 2.830522f-5 1.445772f-5 0.00026255316 -3.3108037f-5 0.000114365124 2.133251f-5 9.390646f-7 8.597302f-5 4.4832537f-5 -0.00014640631 4.605977f-5 -3.259721f-5 -3.3411165f-5 -5.533442f-5 0.000119333024 -0.00016218764; 6.802138f-5 -0.00014965044 4.0020455f-5 -1.1039857f-5 7.1190625f-5 -2.4735142f-5 -8.522056f-5 -0.00017710742 -0.00013092512 6.844086f-5 -6.105174f-5 -0.00011972776 -0.00010891511 -7.82705f-5 7.8629164f-5 5.7108555f-6 1.3864485f-6 8.165053f-5 0.000106433996 2.0753885f-5 9.380462f-6 -0.0002677551 3.6972826f-6 6.113414f-5 -2.4676856f-5 5.3898108f-5 -6.569027f-5 -3.968923f-5 -0.00013129874 -0.000102500395 
8.092122f-5 5.704277f-5; -0.00019121471 4.120586f-5 2.5368441f-5 9.491956f-5 2.5218977f-5 -8.651238f-5 9.647222f-6 -3.1341628f-5 9.225955f-5 -0.00018859463 -0.00012135668 -5.1857744f-5 -0.00018005872 8.144036f-5 -0.00011253409 7.0428425f-5 -1.2069114f-5 -0.00035696127 -8.832393f-5 -5.9433547f-5 -0.00029694627 -0.00012421954 8.795808f-5 -3.1825723f-5 -6.147895f-5 -0.000115042334 2.0978712f-5 5.2176412f-5 0.00013537903 -8.725394f-5 0.00011795276 7.551107f-5; 0.00014442831 3.6347566f-5 -1.0131888f-6 2.3611945f-5 -4.6088513f-5 -0.00011724409 -0.00027096373 0.000221842 -2.5976782f-5 -0.00016725456 4.859497f-5 -4.76884f-5 -3.241124f-5 0.00012698925 -7.549256f-5 -0.00016495539 8.248876f-5 -6.963702f-6 -0.00011071092 -6.30624f-5 -6.483337f-5 5.036446f-5 9.271626f-5 8.215085f-5 -1.1273062f-5 3.132929f-5 9.571084f-5 2.3133522f-5 -0.00029147835 -8.890626f-5 -4.596219f-5 -3.8792663f-5; 0.00015148266 0.00018308489 1.6285765f-5 0.00012868048 5.6193752f-5 9.05283f-5 -8.338846f-5 -5.5408524f-5 2.4316812f-5 -0.00013931195 -3.6767276f-5 2.3714832f-5 -0.00014210993 -1.9002196f-5 4.3610293f-5 1.079502f-5 -0.00010347235 -5.085065f-6 -7.057893f-5 -0.00020427756 -0.00018110809 0.0001084049 2.8269222f-5 -0.00019040906 4.2990963f-5 1.527531f-5 6.363231f-5 0.00010667375 0.00015261168 0.0001169233 -5.1678544f-5 6.968354f-5; 8.5498636f-5 7.99024f-5 -1.4908019f-5 -6.046004f-5 0.000101323174 0.00022266283 0.0001302353 -0.00014010066 -1.8937382f-6 0.00013134758 -0.00014602781 -2.6318436f-5 -9.510272f-5 -3.7440575f-5 -7.040526f-5 7.12034f-6 -0.00026122914 0.00020557042 -6.88432f-5 8.528619f-5 0.00016476003 -2.058454f-5 7.915418f-5 9.782308f-5 -9.304837f-5 -0.00011126425 9.265743f-5 -0.00017446461 7.969163f-5 0.00027671258 6.014508f-6 0.00018513532; 4.6740017f-5 -0.000111192574 7.3510346f-5 0.00011231445 0.0002020917 4.862975f-5 -0.0001492833 -0.00010839264 -0.00017475238 -4.040398f-6 0.00021343923 -4.599124f-5 3.92206f-5 -0.00014661928 -3.1543863f-5 3.077792f-5 -0.00021943511 -0.0001134141 
1.8319772f-5 -0.0001418251 0.00011447056 0.00010144557 -2.028438f-5 -9.2176604f-5 0.00019833498 3.501238f-5 4.2954933f-5 0.00019390204 8.750545f-5 -2.0903666f-5 -9.844146f-6 -1.3439047f-5; 0.00010862108 -5.4049855f-5 3.1523352f-5 -5.229407f-7 -5.646683f-5 9.3628674f-5 -2.3723242f-5 -3.0227573f-5 -0.00012491102 -3.8090122f-5 2.6423735f-5 -1.7008828f-6 -0.00014012578 9.925401f-6 -9.623898f-5 6.002358f-5 -2.6363712f-5 8.353125f-5 6.0895927f-5 -9.6978736f-5 0.00014134769 -1.523072f-5 5.0812618f-5 -1.7103204f-5 1.949806f-6 7.516954f-5 0.00012835025 7.392135f-5 -5.843239f-5 4.5119745f-5 -6.0419236f-5 -9.767831f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-5.1723284f-5 -5.5606462f-5 -2.4906276f-6 4.716476f-5 -7.044862f-5 -0.00015437588 7.139034f-5 -3.5005374f-5 0.000101318656 -1.1782184f-6 -8.3712f-5 -4.753622f-5 -3.8130256f-5 -0.0001821545 0.00012835315 -2.749064f-5 -3.760277f-5 7.940098f-5 3.2928608f-6 -8.690653f-5 0.00014309613 6.0535735f-5 -0.00013417762 2.14173f-5 -7.7093486f-5 1.289954f-5 8.431952f-5 -1.7344386f-5 4.1332452f-5 -1.281316f-5 7.310865f-5 -0.00017549275; -3.2488322f-5 5.8646834f-5 -9.536376f-5 -0.00017939266 0.00020423727 -1.9998074f-6 5.999138f-5 0.00016413705 3.2045777f-5 -1.6727443f-5 -9.364735f-5 -0.00013969475 0.00015848491 -0.0001390261 -7.0130896f-5 7.053921f-5 -2.0182437f-5 7.636344f-6 1.8289767f-5 0.00013951285 -4.9108523f-5 -8.654883f-5 -6.470531f-5 4.031347f-5 -9.151338f-5 4.8044803f-6 -0.000104191226 -4.7172427f-5 -2.6058271f-5 0.00010375413 -0.00015031644 1.9016326f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64
# Convert the Float32 parameters from `Lux.setup` to Float64 and wrap them in a
# `ComponentArray`, which is what the ODE problem and the optimizer receive.
const params = ComponentArray(f64(ps))
# Bundle the network with its state `st`. Parameters are given as `nothing` here and are
# supplied on every call instead (see `nn_model(x, nn_params)` in `ODE_model`).
const nn_model = StatefulLuxLayer(nn, nothing, st)
# StatefulLuxLayer{Val{true}()}(
# Chain(
# layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
# layer_2 = Dense(1 => 32, cos), # 64 parameters
# layer_3 = Dense(32 => 32, cos), # 1_056 parameters
# layer_4 = Dense(32 => 2), # 66 parameters
# ),
# ) # Total: 1_186 parameters,
# # plus 0 states.
# Now we define a system of odes which describes motion of point like particle with Newtonian physics, uses
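$$u[1] = \chi, \qquad u[2] = \phi,$$
$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{F}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{F}_2(\chi)\bigr),$$
where $\mathcal{F}_1$ and $\mathcal{F}_2$ are the two outputs of the neural network, and $p$, $M$, and $e$ are constants.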
"""
    ODE_model(u, nn_params, t)

Out-of-place right-hand side of the neural ODE.

Both derivatives share the factor `(1 + e cos(χ))^2 / (M p^(3/2))`, with `p`, `M`, `e`
read from the global `ode_model_params`. Each is multiplied by `1 + nnᵢ(χ)`, where `nnᵢ`
is the `i`-th output of `nn_model` evaluated with parameters `nn_params`. Only the first
entry of `u` (χ) enters the right-hand side; `t` is not used.

# Returns
The two-element vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ = first(u)
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple hence we can safely ignore
    # it, however, in general, we should use `st` to store the state of the neural network.
    y = 1 .+ nn_model([χ], nn_params)
    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))
    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]
    return [χ̇, ϕ̇]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Neural ODE problem with the same initial state and time span as the true model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# Solve with the untrained parameters, using the same fixed-step RK4 settings as the true model.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# `ode_model_params` is passed so that `soln2orbit` can read `p` and `e` for the 2-row solution.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
# Overlay the reference waveform and the waveform produced by the untrained network.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    # One legend entry per waveform, each combining its line and its markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
# Mean-squared-error objective used to compare predicted and reference waveforms.
const mseloss = MSELoss()

"""
    loss(θ)

Return the mean squared error between the waveform predicted with network parameters `θ`
and the reference `waveform`.

The neural ODE `prob_nn` is re-solved with `p = θ` using fixed-step `RK4` (step `dt`,
saved at `tsteps`). The solution is converted with `compute_waveform`, and only the first
element of the returned tuple is compared against `waveform`.
"""
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
# Warmup the loss function
loss(params)
0.0006935281318993377
Now let us define a callback function to store the loss over time
# Loss value recorded at every call of `callback`.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses` and print a progress line.

`θ` is the state object handed over by the optimizer; only its `iter` field is read.
Always returns `false` (in Optimization.jl a callback that returns `true` halts the
solve, so `false` lets it continue).
"""
function callback(θ, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.
# Differentiate the loss with the Zygote automatic-differentiation backend.
adtype = Optimization.AutoZygote()
# Optimization.jl objectives take `(parameters, hyperparameters)`; the second
# argument is not needed by `loss`, so it is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a backtracking line search, capped at 1000 iterations; `callback` is
# invoked by the solver to log progress.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback,
    maxiters=1000,
)
# retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-3.5609664905634165e-7; 1.4296613244352379e-6; 5.735454033128593e-5; -4.254062514519105e-5; -0.00014098311657997185; -0.00012756168143803135; 2.894710405595401e-6; 7.979020665510474e-6; -3.256983518439263e-6; 0.00015863122825965391; 6.488774670284957e-5; 2.0026807760573948e-5; 2.0812891307268697e-5; -3.464597466513417e-5; 1.7330386981468336e-5; 0.0002610456140242203; -3.829276465693236e-5; 2.956269781858862e-5; -9.557679732099874e-5; -7.801672472842521e-5; 6.171630957386165e-5; -0.0001984114933293277; -0.00017896483768704082; -0.00022969715064389344; 8.142747537931629e-5; 4.7045607061539825e-5; -0.00013377974391924976; 0.00014182167069521777; 9.190059790849956e-5; -2.3620979845827615e-5; 1.2974027413300974e-5; -4.8360183427488304e-5;;], bias = [-5.057426989019382e-19, 1.872209940983256e-18, 4.054095512867685e-18, -2.417864339814074e-18, -2.302431643547116e-16, -2.34156719210781e-16, 3.8395459380794065e-18, 6.4327044974498374e-18, -1.3006949339924972e-18, 1.2696841169539576e-16, 3.472898679060274e-17, 1.4820500178884255e-17, 2.0748350085871053e-17, -5.442684327740905e-17, 1.845207735539708e-17, 4.631736874015142e-16, -1.4025780103987914e-17, 3.475135399876685e-17, 2.98599146825312e-17, -1.0058719770948682e-16, 1.2028805806931363e-16, -5.784807443348925e-18, -5.623675716055018e-17, -2.0708506223714276e-17, -2.0768979224393594e-18, 5.920403018017202e-17, 1.1530642632974863e-16, -9.257516968890589e-18, 4.023057382584957e-17, -5.187976342745069e-17, -9.551274449418655e-18, -1.0025629597496272e-16]), layer_3 = (weight = [-9.472417635634532e-5 7.540266475616032e-5 -0.00016021213135473153 7.005402100248418e-5 6.875475325133422e-5 3.70081203661434e-5 -0.0001594612670808657 -1.679496138904695e-5 -2.5182319709807538e-5 0.00022908728230023955 -8.126692895102892e-5 -4.452941389268545e-6 0.00012502623091194823 -0.00011843272417109353 3.0321515318695783e-5 1.0486993821860198e-6 -0.00010419879739429673 
0.00010089566300366154 0.00013128960805688808 -0.0001979025522226044 -6.416037499823435e-6 0.00018906871506690475 -0.00016956851808143015 9.031353207750838e-5 0.0001109199171557519 -0.0001975638127399217 3.18056687908301e-5 -0.00011021332951285156 -2.0502945822316688e-5 -0.000201680738732926 -3.194180074870564e-5 -0.0002613428199177215; 1.8078753943821294e-6 0.000132559377009505 0.00012124734384332382 -0.00016649278368941747 0.00023776693014285197 4.41771671164489e-6 4.9807768293879464e-5 0.0001109987646420388 9.30641912352196e-5 -0.00010915759740540733 6.156136841372088e-5 9.70318378012307e-6 2.5104400622415378e-5 0.0002649644597046898 7.619358111032207e-5 9.148741118461857e-5 1.5456404428452085e-5 -2.863426708000385e-5 8.385219588734108e-5 -2.538799308905615e-5 0.00014740842779491573 -3.1519749979773754e-5 -6.593985709273041e-5 8.701748483202251e-5 1.5306867128570393e-5 -2.9738328164339113e-5 -0.00021258533489101685 9.699050612374215e-7 -4.999589473965604e-5 0.00018538621988472423 3.21337687570968e-5 -0.0002364996613389935; -8.296318328866632e-5 1.2734729622752509e-5 -8.752075582669905e-5 -9.424773182399722e-6 5.8977594832688186e-5 -1.89234821875271e-5 -0.00014544110372589797 -0.00011883967881949515 5.556286787104896e-5 0.0001729036660406592 -9.502923121835885e-5 0.00016190956311838809 5.185890704233308e-5 -4.77280052138667e-5 -0.00011378732661178228 -8.802863949603437e-5 -3.883770905595153e-5 2.004587012997942e-5 1.276262655374105e-5 -1.9771875216246703e-5 -0.00012647477770151425 -0.0001581618204938378 -9.719867439634622e-5 0.00015361814375574486 7.026430766636685e-5 2.2603731365963006e-6 0.00019388025184234917 5.463188969590579e-6 -9.182893036559023e-5 0.0002580157809432421 -3.2640477641689464e-5 4.775758434043397e-5; 1.792907624230817e-5 -1.3676226623192393e-5 -2.1527734312528713e-5 2.7233046970378733e-5 -0.00012037360847560868 0.00011522951699927068 -8.371589337575618e-5 0.00010137092049599261 -0.00011709800146154497 -1.6847832918105377e-5 
-7.958138139319737e-5 5.387564470193002e-5 -0.00027420700480097315 5.680182385158424e-5 -8.721097965126278e-5 2.5285456835029044e-5 -4.752719377965953e-5 -1.9041092124370327e-5 -2.47648496918319e-5 8.375734051124167e-5 -2.338180474571106e-5 1.1170733890798785e-5 -4.929102051053619e-5 -0.00018166970239350756 6.016126081131405e-5 -2.6499785414033214e-5 0.0001611172208901838 -1.878588063508437e-5 -0.00016997681650199822 0.00020500780538322994 5.917379694774386e-5 -5.8407511638777825e-5; -3.416200862111003e-5 -1.772553305917829e-5 -5.0758955428224706e-5 1.6743681889277321e-6 -6.022417956692933e-5 -3.8955445877132676e-5 1.7105019972456542e-7 -6.009254294176069e-5 2.1412721519396263e-5 -3.6265502519301994e-5 -4.040945505684924e-5 0.00016381556842398395 8.116066082302972e-5 6.847738613296743e-5 0.00016058360260367896 4.171872050652113e-5 0.00013086857715543788 5.977929894734639e-5 -0.0001223853739312298 0.00010819899243224538 0.000103737273359823 -0.00018038481148971412 2.5170004203466996e-5 6.430124112090927e-5 -7.521452581122425e-5 3.4483108077351066e-5 0.000138312973188379 -6.581457065031715e-5 -8.336042602159108e-5 4.803073379016966e-5 5.0221561007588616e-5 0.00010138264978459474; 6.377244947799006e-5 -3.94790534002754e-5 2.6830220788547498e-5 -1.5960326224734363e-5 0.00021797010934532777 -4.289886625629472e-5 0.00017478797029326195 0.00010248737610964557 -4.0852594126765724e-5 -4.6397306748267144e-5 -9.344818887911544e-5 -4.035681401291431e-5 0.000171481091213361 2.078392365802461e-5 0.00021278354482347586 9.637233391255093e-5 -3.249547650484003e-5 -3.555517675267131e-5 -3.126849539257696e-5 1.1386511444291886e-5 2.459414802870468e-5 -8.563015059482432e-5 -5.8699921525853166e-5 -3.9581156913507964e-5 0.00011484425134113103 4.193649304652905e-5 -7.336582554417386e-5 2.9613605632033665e-5 2.788172581204145e-5 -0.00010616168062114666 -1.0039429314194912e-5 -2.761126985860212e-5; -1.7770927416945585e-5 4.424680507209361e-5 2.5150563186530433e-5 -0.0001370283383323975 
-0.00010126634454654212 -4.996979199067391e-5 -0.00011802811362148986 0.00021078745036784774 0.0001075995502730666 8.142827270485809e-5 1.491368008710676e-5 -8.962204749995632e-5 0.00013902862398818352 5.519493485617666e-5 -0.0003207574439614245 7.810740195846775e-5 0.0001448367299132816 -7.221544990852009e-7 8.713664707893765e-5 -4.9791865723177226e-5 0.000129146563875752 1.7316291802739385e-6 9.396318699788908e-6 -4.7146527623491934e-5 -0.00017026582948464653 4.2327334693896e-5 -0.00025710300117797846 -7.64556273964388e-5 0.0001703125932806011 0.00020927722350161737 3.70566746539271e-5 -0.00019672068750086947; -0.00026314144543431426 -4.5666921327949547e-5 5.231392151034681e-5 -0.0001740682555508215 0.0003168450430122772 2.7307060913868548e-5 0.000147893363641913 2.8405688676855475e-5 -1.2066406274982306e-5 3.5334026502082515e-5 5.013556164629993e-5 -5.332347700897333e-5 -7.242794436461953e-5 -1.1359765271492921e-5 0.00012258067090004955 3.3790770426258556e-5 0.00010073716930054389 -0.00015202073609981416 -8.895738578588202e-6 3.925664804700048e-5 8.613549379805343e-5 -4.486405578501615e-5 -1.3863431193527759e-6 -3.6460983214799636e-5 -0.0001007251985400335 5.997954754263844e-6 9.417549244534335e-5 -0.0001440627593227095 0.00012318429889563668 -0.0001591257814696404 -0.00010304139778310159 0.000195669455020095; 0.00010650501399333501 -0.0001698132861728323 -1.6460996173808497e-5 -4.1857022277218165e-5 -8.286049951298915e-5 -8.975598275161558e-5 -9.223752813148398e-5 -0.00011571699969670817 7.9483555329823e-5 -9.246373765368203e-5 -7.240043732770356e-5 -6.075366580154964e-6 -0.00014196687215906612 -7.548270032009999e-5 -3.5705283787940255e-5 1.5996408639873717e-5 -7.004590833992628e-5 1.8026240743091386e-5 -4.461510861608241e-5 -8.932281589102368e-5 0.00017569317178833617 0.0001381289014736137 -4.182250149628533e-5 -0.00024276535504925092 -2.188492724485299e-6 -2.9721688387143884e-6 0.00018601694073562476 9.064603118726256e-5 1.0483523785412554e-5 
-1.872895945587689e-5 -0.000113990873340192 -2.4231809855227597e-5; 4.601934131730411e-7 -9.228896080680016e-5 8.838064311529058e-6 -0.00013861723120720952 -2.1898116322016678e-5 0.0001516628458171479 0.00012348442985745325 1.4577506115128614e-5 -5.4115667373880855e-5 -4.5821257592661835e-5 6.956469840186305e-6 1.6378071884500524e-5 -1.5570121397504864e-5 -3.890750597239118e-6 3.0409108583225642e-5 0.00010875949235699478 -0.00012183962237253666 7.879520715283373e-5 -7.164067801060766e-7 7.001400670338364e-5 -5.847969216342056e-5 3.6896004965040424e-5 0.0001656831941341316 -5.6492951919119803e-5 -0.00011082045377628957 -0.0001719439547036393 0.00016412017291945823 -2.838283173555179e-5 0.00010892430734889429 0.0001385878044653141 -8.515865331213753e-5 9.671809537172403e-5; -7.756531761847957e-5 1.4505216646596714e-5 -0.00010033210185943693 4.519743685372585e-5 3.619269015273844e-6 -9.051626023725202e-5 -5.8532936812895895e-5 -0.0001298258145049266 -1.9560213908022818e-5 3.2787046953983876e-5 9.019074729289337e-5 0.0002201880846463092 4.8640932951769254e-5 -3.1552309643820275e-6 -9.106581331585123e-5 0.0002312545251391773 -4.1444906522571654e-5 7.639302872698234e-5 -6.962527607689204e-5 1.1774703728797742e-6 -2.4589981543301235e-5 5.157023348723932e-5 1.3792158518689796e-6 6.016715575819291e-5 9.322035418003275e-5 -0.00022091559424885135 -0.00022077818051334993 5.2217688213516325e-5 0.00011952617148048652 -2.7391929173661013e-5 -5.430461590881979e-5 9.132612409879863e-5; 2.656021045175088e-5 -1.0779021769960539e-5 2.8549953847576777e-5 -8.735031567277626e-5 5.423814966696611e-5 8.447963616966296e-6 1.1825891513469968e-5 9.747846363814259e-5 -4.623972719966561e-6 -0.00019159509226675624 9.132034777201542e-5 7.99459596490832e-5 6.56475824397402e-5 0.00012872451791543253 -9.347124411953897e-5 -1.7188822104890364e-5 -0.00010515814189795653 6.656704520344468e-5 -0.00013205069242041577 -3.247952028796385e-6 -0.0002097659523971024 7.712213139519968e-5 -5.5933301857396e-5 
2.3943796102589573e-5 9.721772969704025e-5 -3.730821824512672e-5 -4.113917363803597e-5 -3.19775206408992e-5 4.236892366669098e-5 -1.7045383876504345e-5 4.187366378380629e-5 -0.00014527585132975357; -4.0347763025801664e-6 1.2866033593889418e-5 0.0001395975228113822 -2.506487222504874e-5 -0.0001614039978996055 -1.2560314153371633e-5 7.482090908509967e-5 6.476122829539062e-5 6.582899690314485e-5 -1.514896526727708e-5 -4.719526665226955e-5 2.993448428227412e-5 -0.0001416146737575779 -8.447423575504326e-5 4.4093808846663014e-5 -1.557583636836291e-7 6.69078758981768e-5 5.30075616537816e-6 -0.00010154145680658631 1.50488863617665e-5 -2.646512662991838e-5 4.706818757755459e-5 0.00020551764058823982 -8.074578041498928e-5 4.617179323744702e-5 -2.4567726050154792e-5 -9.828911102899888e-5 -0.00012839808592630028 -6.608093504851838e-5 -1.5996196871275603e-5 -0.00016291512698459137 -9.951422949612651e-5; 0.00016269474305408455 8.222034536265218e-5 6.598406049851678e-5 2.192414628188491e-5 0.00016464646047642645 9.554523582300394e-5 -0.0001817277427417188 -0.00021240462068272375 3.275216455767314e-5 0.00012568219405328988 -0.00013611518646394492 0.00017477246889574833 -3.0100400394081077e-5 -8.72014991127849e-5 7.859894118700555e-5 0.00012685749403861999 -9.353024346092348e-5 1.4543538569336775e-5 0.00013108759027635972 0.00019260082059735755 -5.810586866028262e-5 -4.128291251626833e-6 -0.00016568675913186678 0.00014768175039170175 -3.296093403515088e-5 -0.00024017703820755341 -0.0002063807389586212 -6.324533590851165e-5 7.915472248532175e-5 2.479709816361224e-5 -4.872300209176873e-6 -0.00012305801486400526; -0.00015758480685589406 -0.00013438148698595854 4.2290561851688624e-5 -7.314875210694419e-5 0.00011123096237599742 1.0978724732123088e-5 7.56618323986089e-5 5.5342025462877615e-5 -0.00010070298549079838 -4.6229546113747534e-5 -6.64137619413737e-5 9.434214780513654e-5 9.918508568361795e-6 4.581040821095507e-5 0.00012507660039874242 -1.5651250122428526e-5 -6.411010644923369e-5 
3.9603641654239257e-5 4.45340765610677e-6 0.00011362211853439824 -2.350814020145467e-5 -9.670038661976084e-5 -4.171857972220748e-5 0.00011319387022117372 -6.665918726765611e-5 -6.801836525381897e-5 -8.117198694788161e-5 -0.0001685333186277384 -9.266444203502437e-5 6.360764938197205e-7 -2.0345265055666928e-5 -5.807294610248293e-5; 9.448050116900585e-5 2.9121978180615674e-5 -0.00014080508696367907 -3.7387252328473196e-6 5.8973136795806785e-5 -1.1092012179309358e-5 8.87985967287258e-5 0.00012089848463203984 -1.1046899423103337e-5 -3.4554674463460174e-5 -0.00012677493065582583 1.0938118642503163e-5 4.014976151067972e-5 -3.737813528148291e-5 9.348878087023487e-5 0.0001464523680928195 -6.636978502372128e-5 3.721258111311307e-5 8.375176697283724e-5 4.851061142673941e-6 0.00029084715209448745 -0.00015202931387338798 -1.985332197963808e-5 -1.266408286629037e-5 0.00010003811770131688 -1.3382417243133566e-5 -5.5850733012070674e-5 3.6968032539711284e-5 -6.205546230064926e-5 -4.939077040632985e-5 -8.158886190956238e-5 8.02676528414163e-5; -8.470311296750024e-7 1.3161927698851371e-5 -0.00010351052828916048 0.00011588768625081927 1.6447402730426116e-5 0.00023365288011411493 -0.00016762406944614296 -4.535212560873875e-5 0.00011811009112428483 0.00012816894245480767 -5.253632975943152e-5 7.672384031032723e-5 1.668174313529903e-5 0.00021453915774039442 -2.86665467239492e-5 6.46077847239028e-5 8.364372711078698e-5 -4.4529777020816513e-7 -0.00021000821057093455 -5.187394840611047e-5 -5.5710965636253105e-6 -5.147449469513713e-5 0.00014482772762567815 -1.8827958639310946e-5 1.0826259812191715e-5 0.00017652257024458989 5.714993544535884e-5 1.0825401249183467e-5 -0.00021353632241805158 -3.615094233388513e-5 3.6466534355744406e-5 6.312670534392258e-5; -1.583803606746651e-5 5.4165068998086006e-5 0.00010834880157079658 -1.7028268136814246e-5 3.583041627905826e-6 0.00013723510999869845 1.2566075549952893e-5 -7.475342937806553e-5 -3.890728582805513e-6 -7.074297247282749e-5 
-4.197623713236177e-5 -5.937981115249844e-5 -0.00012835444141939213 -1.7850689634828922e-5 -2.2589924644594293e-5 -8.439312006640174e-5 -8.962669455042021e-5 -2.6966515012932148e-5 0.00011893555273007793 0.0001384873168560133 3.1754361158941487e-6 -5.3247866905902946e-5 9.645553457348561e-6 -9.342991763088059e-5 -4.1493866608392724e-5 -0.0002306819376984229 -6.283194017424868e-5 3.2295181138295405e-5 6.09759510395508e-5 0.00015592449852436854 0.00014617222694385864 4.251746502876382e-5; 0.0002762111639192097 -3.9246989717217845e-5 -7.370846064922068e-5 -6.3488686964126e-5 3.0321917120730376e-5 0.00014770109050978605 3.0620418738830145e-5 2.564705068899806e-5 4.706192567677009e-6 4.144998315559894e-5 2.7990805802539045e-5 1.0055951361858798e-5 3.552809580534234e-5 -1.5136185128800912e-5 -2.9674352435255327e-5 -5.258224477922581e-5 5.0999844871251374e-5 8.703941723191727e-7 -1.3233131129974052e-5 3.3463517366463386e-5 -2.3911120978744457e-5 0.00024121098472341398 5.76836086572227e-5 -3.3668967761800035e-5 -8.854279944834133e-5 -6.833742877396372e-5 2.6982216196033075e-5 8.709920298379357e-5 -7.643194617166107e-5 -5.7831870025699056e-5 -0.00018403764038274516 8.937639581810127e-5; 0.000103761261932635 -3.640960311927898e-5 3.139964065453155e-5 -6.482055868432823e-5 -1.4453336164883197e-6 1.554675328193603e-5 6.391365787534907e-5 -0.00037255446476806955 -0.00012627526120227825 -5.234007948487599e-5 3.5641498477749346e-5 -8.231813807684043e-5 -6.283924266641071e-5 8.243203956921585e-5 7.13292120123184e-5 -8.571189936657447e-5 0.00028041334190460866 -0.00011740343870074507 -0.00010483024197342366 -4.6564518918196746e-5 -0.00010396413380690707 0.00010134260993042284 -4.843987880334411e-5 2.3542592323497602e-5 4.852499848900584e-6 -1.930717552782445e-5 -8.30726476054639e-5 0.0001800658810749662 5.297110147480323e-5 -8.189106658716106e-6 2.3589736890829535e-5 -0.00011482448280035456; 3.69164368698922e-5 0.00012296223903331805 0.0001061447288470858 0.00017625174550130678 
7.191830408781507e-5 0.00016198180273430525 -0.0001299796316793571 2.867736802227686e-5 -4.6856377962735586e-5 -9.668030216337047e-5 -6.491222703107232e-5 -1.8909459327456114e-5 -0.00020304903264831893 8.416080126562862e-5 0.00013445605472452808 0.00019481181115627713 0.00011383380078268917 -1.4140181888702373e-5 1.8313872346420708e-5 0.0001694907801670509 -7.362226738677365e-6 3.9010457471214184e-5 3.065837481196407e-5 -6.836009961767e-5 8.353493609338927e-8 -9.763737039518087e-6 -0.00022457290142985118 -7.324963772181706e-5 -3.2120060864603434e-5 9.33506194946899e-5 -6.497277518020133e-7 0.0002200046397724714; -3.78212803211573e-5 3.476726716099796e-5 8.044149135013956e-6 -5.164771655916386e-6 0.0001363382457470951 0.00013826030636620334 -9.969616703972913e-6 0.00011684969514091059 -9.366096627717348e-5 -3.790374966273418e-5 0.00013571097088926047 -0.00013638108929670857 2.1992617357238453e-5 -5.596673644548719e-5 8.421463872720046e-5 -1.2592246539938825e-5 1.9415292713456538e-6 -0.00013540893404796227 0.0001466954404436477 3.3246888138021212e-6 -3.2999983423440325e-5 -6.151107746051094e-6 -2.12615882360174e-5 -0.00019095334709207744 -8.047813888923678e-5 2.9663148333849713e-5 -6.403819997148021e-5 -1.1932890165996567e-5 -9.852117499621058e-5 -0.0001066362469741653 0.00021849530322425557 2.68762164610162e-5; -2.9765651397702295e-5 -6.134242830831336e-6 0.00010967020005400857 -9.916513574362298e-5 5.3676718543936874e-5 -8.531600384096025e-7 0.00013161987242701743 -7.124515452242174e-5 2.5507630361880175e-5 1.7569271385512608e-5 3.765169110911608e-5 3.207278512481912e-5 0.00011027624366752187 -0.00015442700880166695 0.00010947675416891767 -0.00020866336473776388 -0.0002761547110095299 6.214723924122174e-5 -2.9625936640597454e-5 -2.0490828155829922e-5 7.675243630719977e-5 -0.00011727871953559155 1.595123847430821e-5 9.649002111468641e-5 -3.6848141298918865e-5 -0.0001143731750017637 -0.00011720527601945255 -5.569030310660449e-5 -6.400124750152154e-5 
-0.00013550192290504137 7.949584340158212e-5 -0.00013781304352969385; -2.398827531865677e-5 -7.289603822132915e-5 -0.0001011481232518733 -0.00011049211902277865 9.231457459896409e-7 -4.3244204464965996e-5 -6.58410150601004e-5 2.1081056144653173e-5 -7.110739683699205e-5 5.783574884558064e-5 8.113393265088729e-5 8.31247539903862e-6 4.381998438188467e-5 -5.543610099553461e-5 5.360446298877368e-5 -7.142978724291048e-5 3.188290643888199e-5 7.725939063388522e-5 -1.9530272441583207e-5 -6.610273853143945e-5 -0.00015132466946229635 0.00022052302516099208 -5.738684343955427e-5 3.3432965535077564e-5 -0.00015761621911511087 -0.00011231174871059449 0.00020389744746066504 -9.789743347166594e-6 2.1323767538748144e-5 6.431001962195745e-5 -3.5982474985899105e-5 0.0002052501352926312; -0.00012704865804466194 -1.270861940924322e-5 -0.0001422490793751784 -8.997688213131485e-6 -0.00010389253731175416 8.081533013552987e-5 -0.00021548523074074423 -5.039554538091238e-6 -3.685346092737988e-5 -9.220662906356764e-5 -0.0001405799892503992 -9.035069234727629e-5 9.749060262634674e-5 -0.0001598942952236408 2.2866674867835026e-7 3.974093050293574e-5 2.83028286381146e-5 1.4455328879701548e-5 0.00026255077244656164 -3.3110428206271046e-5 0.00011436273342638797 2.133011838540468e-5 9.366736995478179e-7 8.597062986755424e-5 4.483014634985284e-5 -0.00014640870068746223 4.605737848266714e-5 -3.2599601412062326e-5 -3.34135555144125e-5 -5.533681252331208e-5 0.00011933063359407012 -0.00016219003447389752; 6.801848088445e-5 -0.0001496533380885809 4.001755827277137e-5 -1.104275394380761e-5 7.118772795013959e-5 -2.4738038849928424e-5 -8.522346012240416e-5 -0.00017711031374700504 -0.0001309280208842098 6.84379616688043e-5 -6.105463522897775e-5 -0.00011973065681003136 -0.00010891800899470326 -7.827339992571651e-5 7.862626683719343e-5 5.707958404112239e-6 1.3835514310939242e-6 8.164763006435647e-5 0.0001064310987532025 2.075098803845203e-5 9.377564889397332e-6 -0.00026775799175433965 3.6943855649457295e-6 
6.113124130585419e-5 -2.4679752972479715e-5 5.389521050903084e-5 -6.5693167347371e-5 -3.9692125359458015e-5 -0.00013130164130769012 -0.00010250329193139538 8.091832445289813e-5 5.703987211392342e-5; -0.000191219573595864 4.120100072484908e-5 2.536358021914989e-5 9.491469915972819e-5 2.5214115678876054e-5 -8.651724113132592e-5 9.642361304012384e-6 -3.1346488528489256e-5 9.225468909151367e-5 -0.00018859948670704168 -0.00012136154021119527 -5.1862604438629525e-5 -0.00018006358061667952 8.14355001902923e-5 -0.00011253895411484384 7.042356389601713e-5 -1.2073974746288975e-5 -0.0003569661354295199 -8.832879450620133e-5 -5.9438407904063305e-5 -0.00029695112664468475 -0.00012422440397763502 8.795321570963499e-5 -3.1830583454329554e-5 -6.148381419462525e-5 -0.00011504719498317828 2.097385129037315e-5 5.2171551372351834e-5 0.00013537416759663176 -8.725879945543064e-5 0.00011794789986468392 7.550620591517419e-5; 0.0001444260161105014 3.634527098880444e-5 -1.015484311335139e-6 2.360964942168875e-5 -4.609080854890949e-5 -0.00011724638528278447 -0.0002709660292868101 0.0002218397021959938 -2.597907763005404e-5 -0.00016725685776065722 4.8592675800012933e-5 -4.7690697040833994e-5 -3.241353618535816e-5 0.0001269869572450818 -7.549485229277664e-5 -0.00016495768425837357 8.248646560501712e-5 -6.965997666283397e-6 -0.00011071321562952314 -6.306469361619275e-5 -6.48356689754443e-5 5.0362163225957e-5 9.271396635712022e-5 8.214855558153759e-5 -1.1275357731514615e-5 3.1326995728579306e-5 9.570854678429434e-5 2.3131226104739426e-5 -0.0002914806499876548 -8.890855309176852e-5 -4.5964487010805594e-5 -3.879495846662201e-5; 0.00015148322866944145 0.00018308546289981823 1.6286335571840856e-5 0.00012868104689469075 5.619432280136382e-5 9.052887192874393e-5 -8.33878918137074e-5 -5.540795306378498e-5 2.4317382972277827e-5 -0.00013931137497357084 -3.676670591166774e-5 2.371540298108948e-5 -0.00014210935902601567 -1.900162593747033e-5 4.361086333583582e-5 1.0795590546956806e-5 
-0.00010347177582379908 -5.0844943551952345e-6 -7.057835743116836e-5 -0.0002042769889748339 -0.00018110751685971915 0.00010840547050528906 2.8269792287660537e-5 -0.00019040848989326076 4.299153382371536e-5 1.5275879682213715e-5 6.363287987833057e-5 0.00010667432373802281 0.00015261225356434793 0.00011692387064167768 -5.167797341064075e-5 6.968411192710947e-5; 8.550127318614282e-5 7.990503587081217e-5 -1.4905381196683055e-5 -6.045740150710604e-5 0.0001013258116088741 0.00022266546516930331 0.00013023794241036408 -0.000140098021749676 -1.891100571068504e-6 0.00013135021805082256 -0.0001460251768932609 -2.6315798042727717e-5 -9.510008016348732e-5 -3.7437937810122806e-5 -7.040262056960629e-5 7.1229774746385685e-6 -0.00026122650519226685 0.00020557306202846623 -6.884055978134012e-5 8.528882977571865e-5 0.00016476266687988095 -2.0581903247349177e-5 7.915682004548756e-5 9.782571947450036e-5 -9.304573535900084e-5 -0.00011126161183909586 9.266006605484958e-5 -0.00017446197665977044 7.969426593063721e-5 0.0002767152218964982 6.017145604179129e-6 0.00018513795299242167; 4.6740614299548996e-5 -0.00011119197651010172 7.351094328600686e-5 0.0001123150440541821 0.00020209229925216557 4.8630347838974414e-5 -0.00014928269873803646 -0.00010839204250101166 -0.00017475178091285915 -4.039800467815484e-6 0.00021343982453958814 -4.599064331000843e-5 3.9221195867637304e-5 -0.00014661867961717962 -3.154326540266335e-5 3.077851662213362e-5 -0.00021943451346328677 -0.00011341350099269483 1.832036915561308e-5 -0.0001418245074894494 0.00011447115767782316 0.00010144616763840434 -2.028378284036035e-5 -9.217600662692705e-5 0.00019833557681681296 3.5012978434448855e-5 4.295553018141027e-5 0.00019390264143270816 8.750604403007448e-5 -2.0903068696727875e-5 -9.843548549045169e-6 -1.3438449139587827e-5; 0.00010862127556828335 -5.404966258430369e-5 3.152354430420788e-5 -5.227483268858513e-7 -5.646663747830889e-5 9.362886658999518e-5 -2.3723049242845556e-5 -3.022738063810841e-5 -0.0001249108286072545 
-3.80899295931277e-5 2.642392747583325e-5 -1.7006903943286894e-6 -0.0001401255835742735 9.925593541255332e-6 -9.62387898551345e-5 6.0023772231201244e-5 -2.6363519726884466e-5 8.35314454081423e-5 6.08961195313787e-5 -9.697854374172652e-5 0.00014134787808069485 -1.5230527868710902e-5 5.0812810069316084e-5 -1.7103011931626492e-5 1.9499982601593925e-6 7.516973301898226e-5 0.00012835044209276697 7.392154249026796e-5 -5.843219741001294e-5 4.511993712293466e-5 -6.041904315571748e-5 -9.767811979037239e-5], bias = [-2.5005058943331606e-9, 3.3036419092906236e-9, -3.571614798981752e-10, -1.7445893715708838e-9, 1.9665939412996435e-9, 1.8896869801903373e-9, -1.703768202645391e-10, 2.8466746486187437e-10, -3.2911865515099796e-9, 1.1845964397105433e-9, 1.397560762115301e-11, -9.092874481709823e-10, -1.7648382362936387e-9, 2.999204249213081e-10, -2.0586402485034312e-9, 1.6313000977800166e-9, 2.2214852234302197e-9, -4.227488853763751e-10, 1.4938002221844733e-9, -1.6861962254262713e-9, 3.1392069150710854e-9, -1.323228280477041e-10, -2.6373710176108525e-9, -3.7984423683970593e-10, -2.39089189806616e-9, -2.897071579016578e-9, -4.8609099585921904e-9, -2.295503818272241e-9, 5.705579511256344e-10, 2.6375872188110836e-9, 5.97457753701946e-10, 1.9236830547292775e-10]), layer_4 = (weight = [-0.0007325590183909555 -0.0007364420790909398 -0.0006833265123999501 -0.0006336710554133172 -0.0007512844122215017 -0.0008352116810020592 -0.0006094455443821228 -0.0007158412600566143 -0.0005795169788322319 -0.0006820140724477965 -0.000764547890227238 -0.00072837208913826 -0.0007189660660920634 -0.0008629903866610496 -0.0005524826453750715 -0.0007083264620069355 -0.0007184385383355612 -0.0006014349007124081 -0.0006775429730658093 -0.000767742349016815 -0.0005377395356251861 -0.0006203001529244463 -0.0008150133351878385 -0.0006594185847267783 -0.0007579292332233172 -0.0006679361460942218 -0.0005965158183172278 -0.0006981801464395748 -0.0006395034277471214 -0.0006936488763191151 -0.0006077272321277387 
-0.0008563286376104233; 0.0002106088214290487 0.00030174393575211836 0.00014773343976671487 6.370450900264347e-5 0.0004473344283765379 0.0002410973582504831 0.0003030885780001168 0.000407234244661799 0.0002751428846016653 0.0002263697431672327 0.00014944984652607107 0.00010340244281925206 0.00040158207917289637 0.00010407109506040639 0.00017296626730436228 0.0003136363833550481 0.00022291471839182102 0.0002507335406187176 0.00026138694639958447 0.0003826100217124398 0.00019398859521329104 0.00015654836521801125 0.0001783918265300814 0.0002834106659671881 0.00015158376948305115 0.00024790160663039465 0.00013890577816381332 0.0001959247258342341 0.0002170389243529337 0.0003468512669147317 9.278075667974179e-5 0.00026211352419764136], bias = [-0.0006808358878531535, 0.0002430971981326624]))Visualizing the Results
Let us now plot the loss over time
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # Draw the recorded loss history as a line, then overlay one marker per entry.
    lines!(ax, losses; linewidth=4, alpha=0.75)
    iterations = 1:length(losses)
    scatter!(ax, iterations, losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end
# Finally let us visualize the results
# Rebind the global `prob_nn` to a problem that carries the optimized parameters
# `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
# Same fixed-step RK4 settings as for the untrained run, now with `p = res.u`.
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
# Keep only the first element of what `compute_waveform` returns.
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
begin
    # Keyword sets shared by all three curves so they are drawn identically.
    line_style = (; linewidth=2, alpha=0.75)
    marker_style = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference waveform.
    l1 = lines!(ax, tsteps, waveform; line_style...)
    s1 = scatter!(ax, tsteps, waveform; marker_style...)

    # Waveform from the untrained parameters (computed before training).
    l2 = lines!(ax, tsteps, waveform_nn; line_style...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_style...)

    # Waveform from the optimized parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; line_style...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_style...)

    # Each legend entry pairs a line with its markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end
# Appendix
using InteractiveUtils

# Print the Julia build and environment information.
InteractiveUtils.versioninfo()

# GPU details are printed only when MLDataDevices and the matching GPU package are
# both defined in this session and MLDataDevices reports the backend as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
# Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.