Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakieDefine some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.
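We need a very crude 2-body path. Assume the 1-body motion is the Newtonian 2-body relative position vector $\mathbf{r}$; the two bodies then sit at $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$ with $M = m_1 + m_2$, which keeps the center of mass at the origin.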
"""
    one2two(path, m₁, m₂)

Split a relative (1-body) `path` into the two bodies' paths about the centre of
mass: body 1 is scaled by `m₂ / (m₁ + m₂)`, body 2 by `-m₁ / (m₁ + m₂)`.
Returns the tuple `(r₁, r₂)`, each the same shape as `path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in (χ, ϕ[, p, e]) variables into Cartesian positions.

Rows 1 and 2 of `soln` are always χ and ϕ. With 2 rows, the orbital elements are
taken from `model_params = (p, M, e)` (`M` is unused here). With 4 rows, rows 3 and 4
supply time-varying `p` and `e` and `model_params` is ignored.

Returns a `2 × size(soln, 2)` matrix whose rows are `x = r cos ϕ` and `y = r sin ϕ`,
with `r = p / (1 + e cos χ)`.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Check for `nothing` explicitly so a missing `model_params` produces this
        # assertion instead of a MethodError from `length(nothing)`.
        @assert(!isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# The following function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v, dt)

First derivative of the uniformly sampled series `v` with spacing `dt`: central
differences in the interior and second-order one-sided stencils at both ends.
Returns a vector of the same length as `v`; `v` needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    first_pt = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    central = (v[3:end] - v[1:(end - 2)]) ./ 2
    last_pt = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(first_pt, central, last_pt) / dt
end
# The following function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v, dt)

Second derivative of the uniformly sampled series `v` with spacing `dt`: the
three-point central stencil in the interior and four-point one-sided stencils at
both ends. Returns a vector of the same length as `v`; `v` needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    first_pt = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    central = v[1:(end - 2)] - 2 * v[2:(end - 1)] + v[3:end]
    last_pt = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(first_pt, central, last_pt) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free second moment of a planar `orbit`
(row 1 = x, row 2 = y), scaled by `mass`. `(1, 1)` gives `x² - (x² + y²)/3`,
`(2, 2)` gives `y² - (x² + y²)/3`, and any other `component` gives `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    xx = xs .^ 2
    yy = ys .^ 2
    tr = xx .+ yy
    # Same short-circuit order as before: component[2] is only read when
    # component[1] matches.
    c1 = component[1]
    tensor = if c1 == 1 && component[2] == 1
        xx .- tr ./ 3
    elseif c1 == 2 && component[2] == 2
        yy .- tr ./ 3
    else
        xs .* ys
    end
    return mass .* tensor
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (via `d2_dt2` with spacing `dt`) of the
trace-free moment component `component` of `orbit` (see `orbit2tensor`).
"""
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole components `(h11, h12, h22)` of `orbit`, each computed by
`h_22_quadrupole_components` for the indices (1,1), (1,2) and (2,2).
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end
"""
    h_22_strain_one_body(dt, orbit)

(2,2)-mode strain for a single body of unit mass. Returns the tuple
`(√(π/5)·(h11 - h22), -√(π/5)·2h12)`, with constants built in the type of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    norm_factor = sqrt(T(π) / 5)
    plus = norm_factor * (h11 - h22)
    cross = -norm_factor * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components `(h11, h12, h22)` of two bodies, each computed
with `h_22_quadrupole` for its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q1 = h_22_quadrupole(dt, orbit1, mass1)
    q2 = h_22_quadrupole(dt, orbit2, mass2)
    return q1[1] + q2[1], q1[2] + q2[2], q1[3] + q2[3]
end
"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

(2,2)-mode strain from the orbits of BH1 (`orbit1`, `mass1`) and BH2 (`orbit2`,
`mass2`). The masses must sum to one (checked to 1e-12). Returns
`(√(π/5)·(h11 - h22), -√(π/5)·2h12)` of the summed quadrupole.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    Q11, Q12, Q22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    norm_factor = sqrt(T(π) / 5)
    plus = norm_factor * (Q11 - Q22)
    cross = -norm_factor * (T(2) * Q12)
    return plus, cross
end
"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Convert an ODE solution into the (2,2)-mode strain tuple `(h₊, hₓ)`.

`mass_ratio` must lie in [0, 1]. A ratio of 0 uses the single-body formula;
otherwise the orbit is split into two bodies with `m₂ = 1/(1 + q)` and
`m₁ = q·m₂`. `model_params` is forwarded to `soln2orbit`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # Zero mass ratio: test-particle limit, no two-body split needed.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
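`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using
$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p-2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}},$$
where χ and ϕ are the angles used by `soln2orbit` (the radius is $r = p/(1 + e\cos\chi)$ and the position is $(x, y) = r(\cos\phi, \sin\phi)$), $M$ is the mass, and $p$ and $e$ are the semi-latus rectum and eccentricity of that orbit.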
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the Schwarzschild orbit ODE for the state `u = [χ, ϕ]`.
Returns `[χ̇, ϕ̇]`; `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)
    numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    root_term = sqrt(p - 6 - 2 * ecosχ)
    χ̇ = numer * root_term / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions
datasize = 250
# Keep the time axis in Float64 like u0, dt and the ODE parameters. A Float32 tspan
# would make the solver's time variable, the `saveat` grid and `dt_data` (used by the
# finite-difference waveform) single precision.
tspan = (0.0, 6.0e4) # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1]
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
# Fixed-step RK4 with step `dt`; solution values are saved at `tsteps`.
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first (h₊) component of the strain tuple as training data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the true waveform (the h₊ component kept in `waveform`) against `tsteps`:
# a line with the sampled points overlaid, sharing one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    # Passing [l, s] as one group merges the line and markers into a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])
    # Last expression, so the block evaluates to the figure.
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the angle χ, i.e. the first component of the ODE state) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.
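It is generally not recommended to use globals, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl. The weights are drawn from a truncated normal with a very small standard deviation ($10^{-4}$) and the biases are set to zero, so the untrained network's output starts close to zero.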
# Network mapping the scalar input through cos, then two 32-unit cos-activated Dense
# layers, to 2 outputs. All Dense layers share one small-std truncated-normal weight
# initializer and zero biases, so the initial output is close to zero.
const nn = let init_w = truncated_normal(; std=1.0e-4)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight=init_w, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight=init_w, init_bias=zeros32),
        Dense(32 => 2; init_weight=init_w, init_bias=zeros32),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[4.291766f-5; 4.0336286f-5; -0.00017002574; -0.00018820516; 9.914471f-5; -0.00011514237; 0.0001739122; -2.6746536f-6; 8.418071f-5; -0.0001323037; 0.00013989971; -5.9676222f-5; 0.00016231976; 6.337857f-5; 2.8782955f-5; -1.1641432f-6; 9.804708f-5; -2.8384264f-5; -9.728141f-5; -0.00012919508; -5.8910984f-5; -1.7123672f-5; 7.181182f-5; 0.00018050973; 6.5456414f-5; 2.776005f-5; -6.9834864f-6; 0.0001251161; -1.9841953f-5; -8.762088f-5; -0.00013825655; -0.00012769675;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[0.00016709378 5.943958f-5 -0.000125584 0.00011068538 2.5085407f-7 -0.00022533695 -3.3070242f-5 -0.00013278646 -6.2869934f-5 7.237159f-5 -4.9228114f-5 -3.4546232f-5 0.00013912578 -6.868752f-5 0.00017649919 -3.3151995f-5 0.00015230989 1.7892403f-6 -5.2680174f-5 -8.839365f-5 -0.00010093363 0.00018104338 -6.7947025f-5 6.3332074f-5 -1.4742323f-5 0.00021238126 -7.1441806f-5 0.00013374466 -3.8941926f-5 1.6497115f-5 0.00024616442 9.149688f-5; -6.380264f-5 -0.00010816673 -0.00012695968 -7.03473f-5 -0.00012934583 -3.1661028f-5 -7.690903f-6 0.00017416514 3.880423f-5 -8.643546f-5 -2.162484f-5 -0.00019663268 -0.00014956572 6.999314f-5 2.6329652f-5 -1.7025737f-5 -4.306246f-5 0.00011552764 -9.49372f-5 -3.3783475f-5 9.389012f-5 -0.00016424884 -3.8710918f-5 -0.00015981404 -8.10271f-5 -4.1507465f-5 -0.00010308333 0.00017035127 7.276564f-5 -0.00010376769 -2.2004897f-6 6.6400557f-6; -8.734108f-5 -4.3148648f-5 9.7402124f-5 -6.9887596f-5 -2.6929485f-5 -0.00012566095 0.00017066434 -7.093553f-5 0.00017620681 4.525913f-5 0.00011154674 8.762741f-5 8.359326f-5 0.0001531159 0.00018040337 -4.6330886f-5 0.00014199401 1.1984807f-5 0.00023686228 9.440023f-5 -0.00013847096 1.3814899f-5 0.00014976945 -0.00013158566 5.094211f-5 5.2307438f-5 
-0.00024163109 8.193681f-5 8.517043f-5 0.00012392591 3.907351f-6 -6.9970854f-5; 6.304991f-5 -9.125443f-5 -0.00016810278 -0.00021844427 -0.00013422032 -0.00012678289 9.151099f-5 -9.2912924f-5 -3.3835415f-5 2.9290713f-5 -0.000101232254 -0.00011283908 0.00012376774 -3.83209f-5 -0.0001835285 -1.3972935f-5 9.1373564f-5 -9.717922f-5 6.345475f-5 0.0001334679 -8.604633f-5 -0.0001678554 -1.4805696f-6 -0.00013364309 0.00019585031 0.0001825164 0.00013455507 -9.4965595f-5 8.005323f-5 -2.0195916f-5 1.5344742f-5 2.374159f-5; -3.8599387f-6 -9.087641f-5 7.988888f-5 0.0001688097 0.00018498207 0.00011278298 -0.00018361147 -0.00016789944 1.7940962f-5 -0.00012585457 -2.8209703f-5 7.5923286f-5 3.0100953f-5 -7.0848855f-6 -4.5019024f-5 0.00010910429 -3.3253382f-5 -0.00014098879 0.0001422813 4.419175f-5 1.0994401f-5 -6.42556f-5 -3.5737464f-6 -7.9077f-5 3.833859f-5 -0.00012624277 0.0002399255 6.2296145f-5 1.6932065f-5 -3.2488497f-5 -2.285257f-5 -0.000318722; 0.00021831683 -7.2230614f-5 4.0765557f-5 5.6793044f-5 0.00013187801 5.6257963f-6 0.0001906538 1.5899332f-6 4.914826f-5 0.00014788985 2.3208186f-5 -9.285077f-5 5.265778f-5 -0.00013342152 0.0001137068 4.2553667f-5 6.939355f-5 0.00014699115 -8.881594f-5 -0.00013706766 8.9333174f-5 2.1654892f-5 8.9216446f-5 3.865063f-5 2.9951234f-5 9.600684f-5 2.5860263f-5 -0.0001438142 8.041925f-5 5.846544f-5 2.9003439f-5 0.00013694138; 6.494279f-5 5.415496f-5 -3.511195f-5 0.00011218975 7.443936f-5 -1.3725549f-5 -0.00018207617 9.560445f-5 -8.102748f-5 0.00017595537 4.023502f-5 0.00016202139 -0.00012930707 6.9055626f-5 0.00019563448 0.00012977472 3.1062147f-5 4.6481608f-5 -8.9568246f-5 -9.293715f-5 -2.5708325f-5 8.661203f-5 -2.9715102f-5 5.815873f-5 7.887582f-5 6.0721843f-5 9.281164f-5 -3.8501938f-5 2.7373791f-5 -1.2156092f-5 0.00011224206 -0.00014392719; -5.9013983f-5 6.071566f-5 7.16997f-5 -4.3273976f-5 0.00012460891 0.000106428 0.00016408178 0.000114949056 -2.8014429f-6 8.406015f-5 2.129109f-5 -3.0147152f-5 -6.2416846f-5 -4.8408874f-5 -0.00022442933 
-6.346796f-6 2.1757396f-5 -0.00012902087 4.4110526f-5 -6.22187f-5 -3.9684755f-6 -5.862592f-5 0.0001018621 1.8056306f-5 0.00011720054 3.4708788f-5 3.6564816f-5 0.0001966071 1.8134946f-5 5.874445f-5 0.000115261544 -2.6161104f-6; 2.9057515f-5 -0.00014651613 -1.7324168f-5 -0.00014609947 -4.2264397f-5 4.7025496f-5 -0.00014101472 0.00015409988 -0.00017364143 0.0002025909 -0.000115226794 6.277288f-5 6.891528f-5 0.00011355854 -0.000130974 -1.4613196f-5 -6.377187f-6 5.463505f-5 2.2170863f-5 -0.0003167837 -2.69183f-5 -5.9107846f-5 -2.2189004f-5 9.056244f-5 -0.0001803699 -0.00014796092 -0.00016557403 -6.029372f-5 -0.00021272726 0.00011368013 -5.6663648f-5 6.824418f-5; 5.3611864f-5 8.659691f-5 7.1341135f-5 0.00016318347 -4.510803f-5 -2.7305367f-5 -9.011202f-5 2.5696943f-5 7.764466f-5 -0.00012421627 9.009196f-5 -1.4738791f-5 1.3693545f-5 0.00010826626 -8.0323756f-5 -0.00027006707 -0.00011728386 -7.610609f-5 0.000107259424 -0.00032241168 0.00013390693 -0.00014028994 0.00018527095 1.2839339f-5 -0.00017448945 -6.3982414f-5 -2.2863285f-5 0.00022024411 3.557995f-5 -0.00011053077 -1.315217f-5 0.00014102491; 2.8888335f-5 0.00010536174 -8.1859916f-8 7.307947f-5 0.00012046821 -0.000107365864 1.9386282f-6 -7.128046f-5 -1.7794344f-5 -9.615222f-5 0.00011509687 3.3213244f-5 -7.218303f-6 -6.920479f-5 0.0001478752 -2.5687496f-5 -5.518788f-5 -0.000119006225 -0.00011305291 -9.1638336f-5 1.5679472f-5 2.3024873f-5 9.654777f-5 5.802373f-5 -0.00016158845 0.00022313451 -6.274613f-5 -0.00019852298 -0.00013349997 3.2253972f-6 -0.00013769841 9.1507536f-5; -8.1631726f-5 -5.328689f-6 -3.6337224f-5 -1.7548557f-5 -0.00023725425 0.00011790387 -6.282367f-5 -0.00024689158 9.88756f-5 -0.00021019255 -0.00021156635 3.642304f-5 -0.00019428619 -3.9362807f-5 0.00013619864 -3.0767537f-6 0.00012040138 8.803214f-5 4.8175785f-5 0.0002672965 4.823769f-5 -0.00022525471 -1.555151f-5 0.00013014473 -3.9866794f-5 -6.4339234f-5 0.000112095986 -3.589734f-5 -1.004195f-5 -0.00014645293 -0.000112587106 -5.167916f-6; 7.441632f-5 
-1.20015875f-5 -0.00012361491 -0.00013848675 -0.00010665353 -6.8075235f-5 3.826434f-5 -2.6497035f-5 8.221186f-5 -0.00013392938 0.00013977443 4.7469126f-5 -2.4157293f-5 -7.040143f-5 4.5117707f-5 -9.236934f-6 8.702267f-5 5.9577327f-5 0.000107850035 3.0785325f-6 -8.849845f-5 5.439983f-5 5.3579664f-5 -4.6870588f-5 7.0752416f-5 0.00011719535 1.7216981f-6 -0.00015745091 -6.336137f-6 -0.000119917444 -1.727127f-5 -0.00012455361; -1.473509f-5 -9.6100615f-5 -4.0404875f-6 0.000113461036 2.5581656f-5 -3.261563f-5 -6.377477f-5 -1.8731174f-5 -6.463884f-5 0.0002214911 -9.83414f-5 4.1220927f-5 1.323565f-5 0.00017280957 -2.2875145f-5 4.1685587f-5 -5.428307f-5 0.00016856926 -8.748286f-5 5.1658786f-5 -0.00011796554 -1.7635306f-5 8.261432f-6 -4.017607f-5 8.3723666f-5 4.467811f-5 4.3449752f-5 0.00018410209 -0.00018474282 0.00017493343 -0.00010298561 2.3828869f-5; 9.4719784f-5 3.164673f-5 -0.00019316 8.098401f-5 0.0001305834 0.0001354024 -6.66189f-5 -0.00019881432 -0.00010571176 0.00010395172 3.4819193f-5 7.825655f-6 0.00013030895 -2.2937707f-5 3.354336f-5 -1.8288993f-5 -1.5901296f-5 8.745295f-5 0.0001243969 9.3821785f-5 4.834072f-7 0.0001102105 7.3139076f-5 -0.00017562235 -6.74234f-5 -7.500568f-5 0.000111349604 7.8806115f-6 -8.896412f-5 0.00018795439 0.00015417523 -0.00019431017; 1.6261163f-5 6.449206f-5 7.799482f-5 1.468102f-5 -0.00016035316 4.5927438f-5 -2.3197934f-5 0.00018712821 -6.511111f-5 -3.2943757f-5 8.62925f-5 0.00012645936 0.0001178334 -2.4440344f-5 6.4095744f-5 0.00013758846 5.4093376f-5 -3.0158424f-6 -3.568257f-5 -4.5749548f-5 2.3296603f-5 -0.00015279793 3.6964022f-5 -7.590029f-5 -4.1791573f-5 7.0002054f-5 -2.3880797f-5 3.0146979f-5 0.00011600023 3.281704f-5 2.2449693f-5 -0.00012976883; -4.093586f-5 1.304245f-5 7.017192f-5 0.00011045048 -2.1560321f-5 -0.00010444115 8.480285f-5 -0.00010302977 0.00021304484 -0.0001948587 -1.7404816f-5 0.00019952124 5.408219f-5 0.00014532059 -0.00010446764 -7.055963f-5 -0.00015172288 7.556856f-5 -9.425338f-5 2.0020434f-5 -1.5355386f-5 
-5.4116732f-5 -6.627033f-5 -9.57598f-5 0.00018751778 0.00011825739 -4.0789757f-5 -9.038383f-5 0.00017516206 -0.00013861641 -0.00011742041 -4.7815814f-5; 5.4661272f-5 -0.00013160966 3.753099f-5 -0.00014420837 -8.235453f-5 -6.334927f-5 -1.024179f-5 -0.00017063139 -0.00028889452 0.000111812704 -9.885703f-6 -7.5026455f-6 8.961852f-5 9.544166f-6 7.1710965f-5 -5.795931f-5 -0.00015162549 -0.00021072416 0.00030349335 -3.3329798f-5 0.00021990028 5.8899037f-5 7.284077f-5 4.7172758f-5 8.5204716f-5 2.4178611f-5 -2.0337562f-5 -9.610162f-6 -4.3512387f-5 -0.00013571943 -3.8513164f-5 -9.582493f-7; 0.00010358071 -2.137034f-5 -1.1277798f-5 4.914607f-5 -2.864778f-5 -0.0001765559 -0.000102908525 3.695357f-5 8.907683f-5 -0.00012299322 -0.00014463045 1.0951805f-5 -4.6011664f-5 5.2679265f-5 -0.00014577719 4.958004f-5 4.7711976f-5 -4.188045f-5 1.3692851f-5 0.00010595076 -1.015605f-5 8.838605f-5 9.638356f-5 -3.9797646f-6 4.7784528f-5 -1.3742922f-5 -0.00011346548 -6.0144863f-5 3.16096f-5 0.00016172045 1.6475788f-5 7.188933f-5; 3.5225417f-5 -5.2825737f-5 -0.00019716543 2.0673182f-5 -8.9194335f-5 -7.434729f-5 -0.00014200888 -4.318438f-5 6.5449767f-6 -7.538079f-5 1.863344f-5 -0.000109523244 -0.00010314221 1.9459208f-5 -5.2337946f-5 -2.593018f-5 0.00013726862 0.00023969129 3.140188f-5 2.55631f-5 -0.00012730823 0.00012282278 1.2008254f-5 0.00011319556 -1.2067248f-5 0.00014501798 0.00010206108 -3.810431f-5 2.0168634f-5 2.6939988f-5 6.502203f-5 -6.859928f-5; -2.7437876f-5 5.3153617f-5 -1.0648718f-5 0.000118514145 -0.00021296642 -6.759802f-5 -1.1421586f-5 -4.480166f-5 9.2527334f-5 5.8089474f-5 -1.8569355f-5 -9.59622f-5 -0.00011766607 -2.176205f-5 3.261842f-6 7.4366624f-5 0.000115520364 8.819783f-5 -0.00010883905 1.42024965f-5 0.00010329148 0.00010518614 8.328908f-5 7.675138f-5 -6.8523266f-5 4.4015276f-5 8.070748f-5 -5.451344f-5 -2.9719033f-5 -1.9864508f-5 -6.155364f-5 -4.8927843f-5; -2.148804f-5 -0.00015552554 0.00019622775 -0.0001645183 -6.398155f-5 0.00014655653 1.2670126f-5 -2.3476669f-5 
-3.8093465f-5 0.00010613126 0.00011675587 5.4232132f-5 6.176833f-5 4.226823f-6 -3.3681466f-5 -4.7440084f-5 6.321172f-5 -1.7991555f-5 -0.00010824258 -6.9102985f-5 -0.0002052105 -5.988779f-5 8.361703f-5 8.997843f-5 -0.00012944113 5.509434f-5 -0.00014295461 5.734231f-5 -0.00020242367 -6.93616f-5 9.266659f-6 7.139597f-5; -1.3971383f-5 -7.2293056f-5 -0.000108572334 -3.5187742f-5 -6.378012f-5 -3.2922493f-5 8.553494f-5 -4.897797f-5 -0.00012794524 -3.0311283f-5 0.00017795716 -6.727271f-5 0.0001849002 -5.036112f-5 0.00013655484 0.00016317508 -0.000107373766 7.3991187f-6 -4.3108455f-5 1.9056964f-5 -1.5228257f-5 -9.16754f-5 5.6429788f-5 7.1872935f-5 -0.00019103808 -0.00015181753 5.7528552f-5 -6.501952f-5 -3.7895654f-5 4.7287303f-5 2.9785235f-5 -8.229831f-5; 2.1349231f-5 0.000195932 -0.00017306703 -0.00023009094 -0.00014081084 -7.521425f-5 0.00012548723 2.9878598f-5 -0.00011976964 -4.132819f-5 -5.354409f-6 5.057319f-5 0.00023267775 -0.00015178462 -0.00015484584 -5.8396985f-5 8.730176f-5 0.00011710177 -0.00010621692 0.00011553716 6.312064f-6 -0.00013035726 -6.599313f-5 -0.00019547004 1.9185312f-5 0.000112728245 6.0281327f-5 -2.2231012f-5 3.0051103f-6 9.606912f-5 5.8400277f-5 9.3773764f-5; 8.6541564f-5 -9.68035f-6 -1.2450808f-5 -6.5123213f-6 0.00018318454 0.00014933189 -0.00013400185 8.8724504f-5 8.371371f-6 -0.00010720692 2.7345188f-5 -3.6904214f-5 -2.6989558f-6 -1.643402f-5 4.875023f-6 7.802861f-5 2.7787706f-5 -1.6961541f-5 9.7173994f-5 3.22688f-5 0.00018259695 -1.2981668f-5 -0.00012564263 -0.00016039808 -0.00011260637 1.5471696f-5 -8.8308654f-5 -1.6827737f-5 -0.00016992794 0.00017024814 -6.6496475f-5 0.00015882483; -0.0001602626 -5.96831f-5 -0.00016436064 8.991112f-6 -0.0001578517 -0.0002833444 0.00030491187 -5.427934f-6 0.00015092839 -0.00021770857 -9.7407006f-5 4.288943f-5 -0.0001982686 0.00010903623 -0.00013135275 6.186347f-6 9.6789365f-5 5.48918f-5 -7.10804f-5 -2.9612323f-5 -5.9283f-5 -5.889177f-5 -3.4641136f-5 -0.00015921686 -8.839221f-5 5.6721095f-5 6.472484f-5 
-7.672062f-5 -7.061808f-5 3.355896f-5 -6.81827f-5 1.3792093f-5; 4.310532f-5 7.948977f-5 0.00012805923 5.579977f-5 3.0407426f-5 -0.00023283837 3.2931628f-5 -2.7427517f-5 2.2440176f-5 7.3054274f-5 -8.425905f-5 3.128448f-5 3.399935f-5 -0.00015383068 1.2227814f-5 -0.00015677513 1.4033621f-5 1.29543205f-5 -1.8592638f-6 -5.1266074f-5 0.00011763964 3.9641705f-5 -0.00014629164 1.584288f-5 -7.15942f-5 -8.90377f-5 0.000115128125 -0.00014289087 0.00013446901 -1.8171504f-5 -0.0001438809 -1.9623083f-5; -8.6575776f-5 5.5790315f-5 0.00014780869 -3.2969656f-5 3.181776f-5 -3.2353662f-5 6.687225f-5 -6.0938324f-5 -6.6896115f-5 0.00013579708 -1.7311946f-5 -7.056366f-5 -2.1000233f-5 -7.416536f-5 0.00017029239 -9.9366574f-5 4.6151636f-5 -5.7094265f-5 -4.6530433f-5 3.6467503f-5 -0.00015506773 -3.3918382f-6 0.00014253873 -7.515133f-5 6.205894f-5 -0.00014082789 -6.554759f-5 -0.00010976345 -0.00010992798 8.9955385f-5 -7.531381f-5 -1.8105451f-5; 1.9929046f-5 -6.8084504f-5 -2.8260609f-5 8.5991895f-5 -0.00011855153 6.4452695f-5 0.000120892844 6.7206856f-5 -6.486055f-5 6.7450565f-5 1.6506141f-5 1.3778262f-5 -9.043382f-5 4.8412858f-5 -8.10092f-5 -3.5717672f-5 9.9571866f-5 -0.00029601148 -6.4562824f-5 0.00014680374 5.0034305f-6 -1.0220719f-5 5.3278658f-5 0.00012678752 -6.2187064f-5 0.00010244394 2.6606624f-5 -9.789879f-7 -0.00011677917 -2.5351532f-5 0.00025653373 8.154276f-6; 8.44419f-5 -7.973504f-5 5.7915368f-5 -4.61219f-6 -7.532752f-5 0.00013940744 -5.8291644f-5 0.0001910532 6.8249574f-5 -3.2246717f-5 -0.00013768465 0.00013420571 -0.00016252697 -4.6309606f-6 -6.6629604f-5 -7.6992095f-5 0.00010632135 8.298316f-5 -5.38392f-5 -0.00011916148 -0.00019887375 9.197743f-6 0.00024060527 -5.855533f-5 1.906734f-5 -8.12116f-5 4.025475f-7 3.1175125f-5 2.7312f-5 -0.0001141006 0.00012554644 6.977738f-5; -2.6971748f-5 -5.163941f-5 -7.7096774f-5 3.9768234f-5 0.00015971037 0.0002278611 7.82455f-5 -5.3368905f-5 4.4857516f-5 -7.7338475f-5 2.1246005f-5 -7.9736354f-5 0.00011453444 0.00012227354 9.953128f-5 
-0.00014265177 7.150085f-5 0.00018899415 0.00013138738 -0.00026135353 4.474136f-5 -9.6745905f-5 -7.381047f-5 0.00014843774 8.72857f-5 -0.0001428445 -0.0001505355 -0.00016254786 3.9965507f-5 0.00011918043 -1.310946f-5 -6.1469786f-5; 5.2314917f-5 8.028539f-5 3.2808573f-5 -0.0001223498 7.4923024f-5 -2.3702309f-5 1.130243f-5 0.00018243349 0.000107935426 4.1080486f-5 7.586158f-5 6.495992f-5 -9.162678f-5 -7.4748925f-5 -4.2615673f-5 1.4565359f-5 -0.00016958339 -2.144165f-6 1.0525446f-6 -1.2455414f-5 5.825797f-5 3.535567f-7 -3.9860133f-5 -0.00011903161 6.4786815f-5 3.820141f-5 -0.00023413751 -0.00011721813 3.9100083f-5 -6.1126107f-6 -7.519985f-5 -6.978796f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-8.056757f-5 0.00013594593 -6.8021254f-5 -4.8739563f-5 9.5585514f-5 -0.000120604855 1.4801412f-5 6.6949356f-6 -5.938069f-5 -0.00011380825 0.00023125451 7.2983814f-5 7.085532f-5 -6.231535f-5 5.368959f-5 -0.000120469056 0.00013754178 -4.9648126f-5 -3.285739f-5 -7.352568f-5 5.5835056f-7 -1.1486374f-5 6.283166f-5 -9.407237f-5 -2.2166849f-5 6.621529f-5 -0.00014865736 -8.429156f-5 0.00019609625 3.3291246f-5 0.00018594282 -0.00011633142; -3.190726f-5 5.6282588f-5 0.00012105152 5.1846864f-6 5.265587f-5 -0.00019868981 -1.22407755f-5 0.00018520709 1.8620993f-5 -0.00019435983 -9.840687f-5 0.0002296274 8.3609804f-5 -0.00016604869 8.4821106f-5 0.00020493339 0.0001570654 -3.0395f-5 0.00026081994 -7.347154f-5 -5.4198474f-5 -3.361748f-5 -0.00017828045 -0.0001656381 0.00016530116 -0.0002377856 -6.994287f-5 -3.857728f-5 0.00010810546 2.7691984f-5 -7.781818f-5 1.6255712f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64
# Lux initialises parameters as Float32; convert them to Float64 and wrap them in a
# ComponentArray so the optimizer sees a single parameter vector.
const params = ps |> f64 |> ComponentArray
# Bundle the network with its (empty) state so it can be called as nn_model(x, ps).
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Displayed:
# StatefulLuxLayer{Val{true}()}(
#     Chain(
#         layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#         layer_2 = Dense(1 => 32, cos),  # 64 parameters
#         layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
#         layer_4 = Dense(32 => 2),  # 66 parameters
#     ),
# )  # Total: 1_186 parameters,
#    # plus 0 states.
# Now we define a system of ODEs which describes the motion of a point-like particle
# with Newtonian physics, using
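$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$
where $\mathcal{N}_1$ and $\mathcal{N}_2$ are the two outputs of the neural network and $(p, M, e)$ are the fixed model parameters `ode_model_params`.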
"""
    ODE_model(u, nn_params, t)

Newtonian orbit ODE for the state `u = [χ, ϕ]`, with each Newtonian rate multiplied
by `1 + N(χ)`, where `N` is the 2-output network `nn_model` evaluated with
`nn_params`. Orbital parameters come from `ode_model_params`; `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # `nn_model` was built with the (empty) state from Lux.setup, so only the
    # parameters are passed; in general the state should be tracked as well.
    correction = 1 .+ nn_model([χ], nn_params)
    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [rate * correction[1], rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Neural ODE problem, initialised with the untrained Float64 parameters.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false) |> Array
# h₊ component of the waveform predicted by the untrained model.
waveform_nn = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
# Overlay the true waveform and the untrained neural-ODE waveform, each drawn as a
# line plus markers and merged into one legend entry per series.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # True data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    # Untrained neural-ODE prediction.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    # Last expression, so the block evaluates to the figure.
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural ODE with parameters `θ` (same solver settings as the data) and
return the mean squared error between its h₊ waveform and the true `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h_pred, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h_pred, waveform)
end
# Warmup the loss function
loss(params)
# 0.0007232226422724241 (value printed in the original run)
# Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(state, l)

Record the loss `l` in `losses` and print it with the iteration number read from
`state.iter`. `state` is the state object the optimizer passes to callbacks.
Returns `false` so the optimizer does not stop early.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.
# Gradients of `loss` come from Zygote via Optimization.jl's AD selection.
# NOTE(review): Zygote is not in this file's `using` list; `AutoZygote` presumably works
# because another loaded package (e.g. SciMLSensitivity) brings Zygote in — confirm.
adtype = Optimization.AutoZygote()
# The objective ignores its second (hyperparameter) argument.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS from the untrained `params`, with a small initial step and backtracking line
# search; `callback` records and prints the loss each time it is invoked.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [4.291766163071503e-5; 4.033628647444361e-5; -0.00017002574168127241; -0.000188205158337539; 9.914470865596873e-5; -0.0001151423712143178; 0.00017391219444106505; -2.6746536150307435e-6; 8.418071229241358e-5; -0.00013230369950149412; 0.00013989971193931671; -5.9676222008385385e-5; 0.00016231976042017668; 6.337856757455601e-5; 2.8782955268930544e-5; -1.1641432138266774e-6; 9.80470795183544e-5; -2.8384263714536246e-5; -9.728140867078656e-5; -0.00012919507571496918; -5.8910984080295344e-5; -1.712367156866087e-5; 7.18118171789089e-5; 0.00018050972721530076; 6.545641372202637e-5; 2.7760050215842903e-5; -6.983486400714133e-6; 0.00012511610111673612; -1.984195296242969e-5; -8.76208796397241e-5; -0.0001382565533273533; -0.0001276967523154436;;], bias = [7.593106070885216e-17, 4.578385686355334e-17, -2.071799658741397e-16, -4.654179405536354e-16, 3.05532413214339e-16, -1.626862633733165e-16, 4.87509362547652e-18, -5.491310781828898e-19, 2.687077619910629e-17, -3.333173745806142e-16, 2.1071229453839871e-16, -3.7710348145813675e-17, 2.0373131887031679e-16, -1.2642589738864591e-17, 5.695462138015738e-17, -1.0835893434796178e-18, 9.349738326578313e-17, -8.391918819326745e-18, -3.984673756119903e-17, 8.823091444971888e-18, 1.0766804094224972e-17, -4.0850618950214833e-17, 6.658712724132609e-17, 1.0949079769605975e-16, 1.1862460549979562e-16, 3.6125166821835886e-17, 7.435798864070156e-19, 1.531743720597479e-16, -2.4960590755671984e-17, -1.6338254305986463e-16, -3.4835296198894183e-16, 6.095149858964609e-17]), layer_3 = (weight = [0.00016709622524412773 5.944202586160488e-5 -0.00012558155630051947 0.00011068782861308748 2.5330117035908906e-7 -0.00022533449863350498 -3.306779525394519e-5 -0.0001327840121908615 -6.286748738294454e-5 7.237404007663134e-5 -4.922566712286373e-5 -3.4543785101600794e-5 0.0001391282299786095 -6.868507388081149e-5 0.00017650163737325503 -3.31495479136741e-5 0.00015231233793509938 
1.7916873852272848e-6 -5.267772702301099e-5 -8.83912047234998e-5 -0.00010093117942833974 0.00018104582214927855 -6.794457784654406e-5 6.333452130399491e-5 -1.473987630252492e-5 0.00021238370628761638 -7.143935853182698e-5 0.0001337471063501507 -3.8939478496652255e-5 1.649956231030134e-5 0.0002461668651570844 9.149932394701265e-5; -6.380712565976191e-5 -0.000108171214922434 -0.000126964169428747 -7.035178405005271e-5 -0.0001293503197283461 -3.166551377749276e-5 -7.695388920026148e-6 0.00017416064965993834 3.879974556378853e-5 -8.643994496910423e-5 -2.1629327054910717e-5 -0.00019663716793535951 -0.0001495702059381379 6.998865446997193e-5 2.6325165346975867e-5 -1.7030223350844185e-5 -4.306694663488109e-5 0.00011552315427445661 -9.494168316256078e-5 -3.3787961545248385e-5 9.388563503501245e-5 -0.0001642533212342415 -3.871540375063644e-5 -0.00015981852320494326 -8.103158743130616e-5 -4.1511950801463444e-5 -0.00010308781620358375 0.00017034678185350402 7.276115084289368e-5 -0.00010377217822484753 -2.2049758607364136e-6 6.635569572745629e-6; -8.733679339565503e-5 -4.314435852501677e-5 9.740641337772602e-5 -6.988330630633257e-5 -2.692519580814874e-5 -0.00012565666463270684 0.00017066862722346141 -7.0931240653581e-5 0.00017621110248801329 4.526341956878517e-5 0.00011155102777482158 8.763170274396952e-5 8.359754804475364e-5 0.00015312019616322122 0.00018040765656115461 -4.6326596813005405e-5 0.00014199829832140963 1.1989096609621566e-5 0.0002368665701658714 9.440452061315617e-5 -0.0001384666719524788 1.3819188214992682e-5 0.0001497737358738978 -0.00013158137505455686 5.0946400850810617e-5 5.231172702115945e-5 -0.00024162679971704177 8.1941096258366e-5 8.517472099702842e-5 0.00012393020381641523 3.911640423012213e-6 -6.996656508928492e-5; 6.304720130554013e-5 -9.125713865234922e-5 -0.00016810549010715312 -0.00021844698227668405 -0.00013422302896816897 -0.00012678560330260858 9.150828094261952e-5 -9.291563498284284e-5 -3.3838126060635966e-5 2.9288002164494866e-5 
-0.00010123496491888293 -0.00011284179078377086 0.0001237650239401768 -3.832361204860927e-5 -0.00018353120418926246 -1.3975646712266266e-5 9.137085265518463e-5 -9.718192910970666e-5 6.345203558720002e-5 0.00013346518511224523 -8.604904457190316e-5 -0.0001678581075482961 -1.4832808718176187e-6 -0.0001336457981467853 0.00019584760132982222 0.00018251369595798184 0.0001345523586989419 -9.496830631706963e-5 8.005051762136117e-5 -2.0198626831023606e-5 1.534203050522803e-5 2.3738878493704402e-5; -3.860499783073537e-6 -9.087697335463977e-5 7.988831848660023e-5 0.00016880913526573808 0.00018498150811818554 0.00011278242081924526 -0.00018361202901406653 -0.00016790000599986924 1.7940400629854562e-5 -0.0001258551282267153 -2.8210263645189606e-5 7.592272513601006e-5 3.0100392086286605e-5 -7.0854465780172584e-6 -4.501958474112714e-5 0.00010910372575334218 -3.325394291142814e-5 -0.00014098935289486348 0.00014228073186994648 4.419118986287403e-5 1.0993840168779794e-5 -6.425615802119537e-5 -3.574307493650319e-6 -7.907756016767048e-5 3.833802757024788e-5 -0.00012624332966926547 0.00023992494161528485 6.229558412040812e-5 1.693150361414485e-5 -3.248905786723745e-5 -2.2853130667965735e-5 -0.0003187225746846827; 0.00021832201056568265 -7.222542879717075e-5 4.0770742285851916e-5 5.6798228994321864e-5 0.00013188319755171333 5.630981573978482e-6 0.00019065897978757707 1.5951183890194486e-6 4.915344581276543e-5 0.00014789503367535313 2.3213370892672253e-5 -9.284558726321882e-5 5.26629639501676e-5 -0.00013341633419586396 0.00011371198817542305 4.255885252533173e-5 6.939873602858685e-5 0.0001469963364947077 -8.88107558999866e-5 -0.00013706247662742006 8.933835962707216e-5 2.1660077633355896e-5 8.9221631439065e-5 3.865581422027942e-5 2.9956419181324514e-5 9.601202802187326e-5 2.5865448348127874e-5 -0.00014380902101367854 8.042443843478894e-5 5.8470626041882533e-5 2.9008623838645625e-5 0.00013694656555924314; 6.494635354165774e-5 5.4158524017751865e-5 -3.510838459927134e-5 
0.00011219331545422488 7.444292614372309e-5 -1.3721985032417082e-5 -0.00018207260391075794 9.560801745529353e-5 -8.102391539005382e-5 0.0001759589348700354 4.023858391581705e-5 0.00016202495216994202 -0.00012930350303515972 6.9059190446479e-5 0.0001956380428061538 0.0001297782863736811 3.106571139205136e-5 4.6485171788576724e-5 -8.956468182014396e-5 -9.293358843871389e-5 -2.570476053669917e-5 8.661559595630537e-5 -2.971153762637235e-5 5.81622943179033e-5 7.887938309131694e-5 6.0725406820624855e-5 9.28152066087077e-5 -3.849837330500473e-5 2.737735567343908e-5 -1.2152528233222009e-5 0.00011224562231353208 -0.00014392362309019035; -5.9010924183188676e-5 6.0718720028286354e-5 7.170276030114695e-5 -4.3270917836035274e-5 0.00012461196701657842 0.00010643105878876992 0.00016408483742884073 0.00011495211464969372 -2.7983845077229867e-6 8.406320802783642e-5 2.129414924161817e-5 -3.0144093493823474e-5 -6.24137876959601e-5 -4.840581578284971e-5 -0.0002244262698743321 -6.3437376458949776e-6 2.1760454451184693e-5 -0.00012901781638467833 4.4113584130871225e-5 -6.22156415422434e-5 -3.965417186214545e-6 -5.862286098383466e-5 0.00010186515527927665 1.8059363987724656e-5 0.0001172036015257376 3.471184636034858e-5 3.656787402613015e-5 0.00019661015784982824 1.813800435664024e-5 5.8747506677509334e-5 0.00011526460247728864 -2.613051997504724e-6; 2.9052980089750483e-5 -0.00014652066835858487 -1.7328703115463978e-5 -0.0001461040033698012 -4.226893231112789e-5 4.7020961055972096e-5 -0.00014101925859879766 0.00015409534301418016 -0.0001736459622132008 0.0002025863699875578 -0.00011523132940942195 6.276834522416542e-5 6.891074502188075e-5 0.00011355400550221819 -0.0001309785389546745 -1.4617731330670464e-5 -6.381722477965131e-6 5.4630514050789854e-5 2.2166328044693393e-5 -0.0003167882337672842 -2.6922835395245863e-5 -5.9112381286265425e-5 -2.21935393435791e-5 9.055790490517677e-5 -0.00018037443312075152 -0.0001479654552620345 -0.0001655785695855026 -6.0298256875969244e-5 
-0.00021273179533616446 0.00011367559402991422 -5.6668183396820246e-5 6.823964707132176e-5; 5.36112078070462e-5 8.659625597189256e-5 7.134047915585062e-5 0.00016318281847176933 -4.510868678793063e-5 -2.7306022854676992e-5 -9.011267833287148e-5 2.5696287014363718e-5 7.764400503489377e-5 -0.00012421692521187065 9.009130586761502e-5 -1.4739447248791137e-5 1.3692888717713228e-5 0.00010826560024994549 -8.032441259005052e-5 -0.00027006773019551254 -0.0001172845164883578 -7.610674368792158e-5 0.00010725876778721119 -0.0003224123371519761 0.00013390627133318337 -0.00014029059242333242 0.00018527029748198722 1.2838683108357105e-5 -0.00017449011069132816 -6.398307017392425e-5 -2.2863941591856606e-5 0.00022024345377118698 3.5579293001285295e-5 -0.00011053142374024526 -1.3152826465587341e-5 0.00014102425335565035; 2.8886961023533116e-5 0.00010536036419260806 -8.323367472951943e-8 7.307809713575172e-5 0.00012046683835841635 -0.00010736723785476689 1.9372544434195474e-6 -7.128183502892544e-5 -1.7795717634281293e-5 -9.615359018744644e-5 0.00011509549361698676 3.321187026493826e-5 -7.219676684600354e-6 -6.920616439266466e-5 0.00014787382091206952 -2.5688869267027373e-5 -5.518925396564507e-5 -0.00011900759860828997 -0.00011305428639708184 -9.163970977488685e-5 1.567809844026447e-5 2.3023499612905666e-5 9.654639283006724e-5 5.802235458441069e-5 -0.00016158982767911954 0.00022313313960374787 -6.274750054536898e-5 -0.00019852434931045223 -0.0001335013425555316 3.2240234636352316e-6 -0.00013769978837763563 9.150616236774456e-5; -8.16348922170897e-5 -5.331855140803665e-6 -3.634038987447634e-5 -1.755172272614125e-5 -0.00023725741741593224 0.00011790070745981422 -6.282683282270086e-5 -0.00024689474341806344 9.887243393229737e-5 -0.00021019571543084823 -0.00021156951809181055 3.6419873845326324e-5 -0.00019428935158876163 -3.936597315708074e-5 0.0001361954698723532 -3.079919883869219e-6 0.00012039821629066699 8.802897243646674e-5 4.817261904015869e-5 0.0002672933394706762 
4.8234522887533984e-5 -0.00022525787901480032 -1.5554676087892282e-5 0.0001301415675471655 -3.986996055112091e-5 -6.434240024185043e-5 0.00011209281981343098 -3.5900507305009326e-5 -1.0045116346129607e-5 -0.0001464561002875817 -0.00011259027198767498 -5.171081943366731e-6; 7.441515252109673e-5 -1.2002753938149636e-5 -0.0001236160774354343 -0.00013848791652009427 -0.00010665469973930731 -6.807640103972959e-5 3.826317468309043e-5 -2.6498201368968127e-5 8.22106935444686e-5 -0.00013393054770888108 0.00013977326805586724 4.746795955695334e-5 -2.4158458946039384e-5 -7.040259744856368e-5 4.5116541041052156e-5 -9.238100138036659e-6 8.702150578726434e-5 5.957616074708308e-5 0.00010784886897424582 3.0773660270611444e-6 -8.849961388507738e-5 5.4398662068406104e-5 5.3578497936243156e-5 -4.6871753984593706e-5 7.075124947702662e-5 0.00011719418169347587 1.7205316950545627e-6 -0.00015745207430911603 -6.3373036212671785e-6 -0.00011991861039889697 -1.727243627130741e-5 -0.00012455477783107524; -1.4733537305397197e-5 -9.60990624531001e-5 -4.038934626415208e-6 0.00011346258853739143 2.5583208675416152e-5 -3.261407672249515e-5 -6.377321968609194e-5 -1.8729621465858765e-5 -6.463729058446788e-5 0.00022149264570037815 -9.833984639550682e-5 4.122247938943779e-5 1.3237202815099207e-5 0.00017281111960477854 -2.2873592193314046e-5 4.168714023257536e-5 -5.428151833762966e-5 0.00016857080792255597 -8.748130911550096e-5 5.1660338990073896e-5 -0.00011796398447182274 -1.7633753109798664e-5 8.262985297226332e-6 -4.017451700344524e-5 8.372521862333358e-5 4.467966156375487e-5 4.345130529549832e-5 0.00018410363865262904 -0.0001847412682632738 0.00017493498618427206 -0.00010298405376085174 2.383042139120909e-5; 9.472182604263462e-5 3.164877229032978e-5 -0.0001931579562205729 8.09860549798285e-5 0.00013058544312848108 0.00013540444082358957 -6.661685880496258e-5 -0.0001988122775055184 -0.0001057097205029563 0.00010395376479031337 3.4821235329254155e-5 7.827697243602466e-6 0.00013031099400726316 
-2.2935665458089054e-5 3.3545403437577126e-5 -1.8286950636999198e-5 -1.58992541964411e-5 8.745499247930009e-5 0.0001243989437517 9.382382735388229e-5 4.8544911856072e-7 0.00011021254353004539 7.314111748023597e-5 -0.00017562030812948257 -6.742136143834565e-5 -7.500363686808388e-5 0.00011135164562624925 7.882653460949087e-6 -8.896208133769199e-5 0.00018795642712397974 0.00015417727003117738 -0.0001943081250482994; 1.62631213132245e-5 6.449401643096478e-5 7.799677495113693e-5 1.4682977951268487e-5 -0.00016035119825772524 4.592939602645307e-5 -2.3195975711312727e-5 0.00018713017287866324 -6.510915319466373e-5 -3.294179897437882e-5 8.629445891329808e-5 0.00012646131743878944 0.00011783535682098293 -2.443838549168459e-5 6.409770229567836e-5 0.00013759041847863346 5.409533422409054e-5 -3.01388423593129e-6 -3.5680611301525774e-5 -4.574758987672654e-5 2.3298561234533632e-5 -0.00015279597483838813 3.6965980132499396e-5 -7.589832877074482e-5 -4.178961445375507e-5 7.000401182069509e-5 -2.3878838885299683e-5 3.0148937173919225e-5 0.00011600218657606371 3.281899715859063e-5 2.2451650682169057e-5 -0.00012976687049757693; -4.093627612626668e-5 1.3042035244616893e-5 7.017150383155668e-5 0.00011045006285833158 -2.1560735693745396e-5 -0.0001044415617786195 8.480243595961375e-5 -0.00010303018607253723 0.0002130444267107318 -0.00019485911579856927 -1.7405230754481333e-5 0.00019952082257197705 5.408177494366494e-5 0.00014532017703579282 -0.00010446805354029442 -7.056004651411019e-5 -0.0001517232954926641 7.556814707761129e-5 -9.425379503107085e-5 2.002001965713399e-5 -1.535580085936003e-5 -5.4117146281605085e-5 -6.62707458092689e-5 -9.57602165714659e-5 0.00018751736970563896 0.00011825697620344778 -4.0790171259395304e-5 -9.038424432099168e-5 0.0001751616440754519 -0.00013861682200393117 -0.00011742082242442552 -4.7816228833359804e-5; 5.465959822676424e-5 -0.00013161133409094727 3.752931645058914e-5 -0.00014421004760159872 -8.235620060086469e-5 -6.335094092728603e-5 
-1.0243463457383146e-5 -0.00017063306602732096 -0.00028889619004995 0.000111811030201262 -9.887376272761408e-6 -7.504319164679378e-6 8.961684927263046e-5 9.542491916266059e-6 7.170929149641879e-5 -5.796098428628228e-5 -0.00015162715879351132 -0.00021072583347641637 0.0003034916733849025 -3.3331471222572345e-5 0.00021989860380917226 5.889736332382272e-5 7.283909491879393e-5 4.717108442483558e-5 8.5203042381084e-5 2.4176937422354128e-5 -2.0339235725359093e-5 -9.61183575791441e-6 -4.351406112628788e-5 -0.00013572110054165456 -3.851483796106365e-5 -9.599229361807768e-7; 0.00010358082399145961 -2.1370223845925367e-5 -1.1277681193652855e-5 4.914618729638068e-5 -2.8647663936472998e-5 -0.00017655578231530982 -0.00010290840837730877 3.6953686861333316e-5 8.907694463526141e-5 -0.00012299310729465247 -0.000144630335495796 1.0951921825134406e-5 -4.6011547607824406e-5 5.2679381399892925e-5 -0.00014577707007153628 4.9580158150301194e-5 4.7712092412647304e-5 -4.188033163406622e-5 1.3692967871273853e-5 0.00010595087987278794 -1.0155933165852697e-5 8.838616521937108e-5 9.63836795501529e-5 -3.979647869320786e-6 4.778464462399733e-5 -1.3742805646662283e-5 -0.00011346536449015804 -6.014474602740983e-5 3.160971791042281e-5 0.00016172057111655597 1.6475904333882112e-5 7.188944472839332e-5; 3.522514407700438e-5 -5.282601023031822e-5 -0.00019716570068089472 2.0672908706505484e-5 -8.9194607878424e-5 -7.434756525731033e-5 -0.0001420091543852741 -4.3184653347679125e-5 6.544703412829883e-6 -7.538106410470141e-5 1.86331667489439e-5 -0.00010952351703713121 -0.00010314248038161748 1.9458935187459633e-5 -5.233821911792701e-5 -2.5930452918372763e-5 0.00013726835050215593 0.00023969101630677973 3.140160652287883e-5 2.5562826979469418e-5 -0.00012730850403136048 0.00012282250418388783 1.2007980788213481e-5 0.00011319528448689582 -1.20675208538601e-5 0.0001450177110225478 0.0001020608065143743 -3.810458337943368e-5 2.0168361063725638e-5 2.6939714646461455e-5 6.502175847595293e-5 -6.859955146614754e-5; 
-2.743751652178511e-5 5.315397679351116e-5 -1.0648358013897727e-5 0.00011851450430822011 -0.0002129660610562495 -6.759765933838961e-5 -1.1421226602581078e-5 -4.480130196770788e-5 9.25276940934059e-5 5.8089833643861454e-5 -1.856899548602188e-5 -9.596183777193714e-5 -0.00011766571374482918 -2.1761691144091964e-5 3.262201787039239e-6 7.436698392393863e-5 0.000115520724236376 8.819818644679008e-5 -0.00010883868931432498 1.4202856206792839e-5 0.00010329184037527827 0.00010518649973801199 8.328943981450789e-5 7.675173723964437e-5 -6.852290648839619e-5 4.4015635964570995e-5 8.070784274150482e-5 -5.4513079206117805e-5 -2.971867293399432e-5 -1.9864148683154403e-5 -6.155328124168687e-5 -4.892748297879797e-5; -2.1489725152697205e-5 -0.00015552722896527583 0.00019622606061322158 -0.00016451997788233786 -6.3983233015366e-5 0.0001465548442110102 1.2668441119709399e-5 -2.3478353508018757e-5 -3.809515026941682e-5 0.00010612957277490609 0.00011675418801761195 5.423044723729624e-5 6.176664766545742e-5 4.225137960839082e-6 -3.3683151453352886e-5 -4.744176902304576e-5 6.321003758503475e-5 -1.799324007433968e-5 -0.00010824426561439067 -6.910467044062173e-5 -0.00020521218237545325 -5.9889473671411815e-5 8.361534429868791e-5 8.99767431446139e-5 -0.00012944281905502022 5.509265549053307e-5 -0.00014295629506077786 5.7340624441082883e-5 -0.0002024253596419856 -6.936328707806518e-5 9.264973702474355e-6 7.139428209203079e-5; -1.3973003473577549e-5 -7.229467682386991e-5 -0.00010857395481629049 -3.518936300145807e-5 -6.37817409316525e-5 -3.2924113651594245e-5 8.553332209509183e-5 -4.897959096902377e-5 -0.00012794686135216574 -3.031290343109865e-5 0.00017795554068093878 -6.727433155370086e-5 0.00018489857957853115 -5.036273959754154e-5 0.00013655321727660202 0.00016317345780939466 -0.00010737538631851611 7.397498125871111e-6 -4.3110075961657075e-5 1.9055343915038574e-5 -1.5229877857702393e-5 -9.167702027676007e-5 5.642816741405926e-5 7.187131469060173e-5 -0.00019103970207120396 
-0.00015181914731706778 5.7526931601247294e-5 -6.502114210356469e-5 -3.7897274336641e-5 4.7285682927259034e-5 2.9783614268905648e-5 -8.229993345061725e-5; 2.1348263726572383e-5 0.00019593103952519854 -0.0001730680012768023 -0.00023009190754218572 -0.00014081180397606865 -7.521521893049899e-5 0.00012548626564258005 2.9877630550067172e-5 -0.00011977060767076478 -4.1329156192477656e-5 -5.355376560873044e-6 5.05722238814956e-5 0.00023267677916028644 -0.00015178559242467988 -0.00015484680607169628 -5.83979524888543e-5 8.730079589458691e-5 0.00011710080452775819 -0.00010621788514963778 0.00011553618987836676 6.311096695022113e-6 -0.00013035823194496114 -6.599409953178466e-5 -0.00019547100946492216 1.9184344819401102e-5 0.00011272727733757582 6.028035950311671e-5 -2.2231979346361465e-5 3.0041428091812468e-6 9.606815019669463e-5 5.839930981910665e-5 9.377279657219327e-5; 8.654240536960237e-5 -9.679508366769352e-6 -1.2449966562077143e-5 -6.511480024008213e-6 0.0001831853851395677 0.00014933272632295348 -0.00013400100851906845 8.87253454489678e-5 8.372212505110265e-6 -0.00010720607951991294 2.734602912785395e-5 -3.690337234767735e-5 -2.698114485997793e-6 -1.6433178644018615e-5 4.875864305706508e-6 7.802944764951691e-5 2.778854741298311e-5 -1.6960700122961257e-5 9.717483499078265e-5 3.22696405415422e-5 0.0001825977933545632 -1.298082679606303e-5 -0.0001256417917833633 -0.0001603972368688367 -0.00011260553128560677 1.5472537597436054e-5 -8.830781280304791e-5 -1.6826895262440497e-5 -0.00016992709697698742 0.00017024897988410903 -6.649563420107666e-5 0.00015882566822217778; -0.00016026768490841453 -5.96881865144881e-5 -0.00016436572251540683 8.986026738528719e-6 -0.0001578567818010266 -0.00028334947193208923 0.00030490678262732137 -5.4330190446719025e-6 0.0001509233051242222 -0.00021771365492305797 -9.741209133927448e-5 4.288434582709776e-5 -0.00019827367860989205 0.0001090311424254305 -0.0001313578314816766 6.181261948272348e-6 9.678427967307657e-5 5.488671457574693e-5 
-7.10854866279335e-5 -2.9617408575645872e-5 -5.928808524325995e-5 -5.889685336437169e-5 -3.464622124150625e-5 -0.00015922194062430815 -8.839729627064799e-5 5.671601037325323e-5 6.471975292375178e-5 -7.67257071070318e-5 -7.062316500518546e-5 3.3553876061404087e-5 -6.818778557636587e-5 1.3787008253060824e-5; 4.310397092156558e-5 7.948842288131334e-5 0.0001280578774018004 5.579842148253752e-5 3.0406077013670984e-5 -0.0002328397192908729 3.2930279514033506e-5 -2.7428865695056573e-5 2.2438827028583977e-5 7.30529256830664e-5 -8.426040109808263e-5 3.128313003942891e-5 3.3998000828061566e-5 -0.00015383203096318448 1.2226465725306713e-5 -0.00015677648004240922 1.4032272003154227e-5 1.2952971912072321e-5 -1.8606123316352543e-6 -5.1267422139438066e-5 0.00011763829136869457 3.964035680071829e-5 -0.00014629298927358692 1.5841530722895898e-5 -7.159554913270665e-5 -8.903904633227337e-5 0.00011512677631942923 -0.0001428922212366282 0.00013446766136583326 -1.8172852622608155e-5 -0.00014388224623727337 -1.962443163248561e-5; -8.65774747906043e-5 5.578861579510604e-5 0.00014780699326753482 -3.297135502701401e-5 3.181605985833074e-5 -3.2355361541473054e-5 6.687055043972966e-5 -6.094002331861163e-5 -6.689781388172358e-5 0.00013579537677674722 -1.7313645217420523e-5 -7.056536218056128e-5 -2.1001931814442366e-5 -7.416705578703081e-5 0.00017029069182558163 -9.936827358228954e-5 4.6149936464873245e-5 -5.709596392603416e-5 -4.6532132033233516e-5 3.646580422742413e-5 -0.00015506942530532975 -3.3935373137445847e-6 0.00014253703171948387 -7.515302807927066e-5 6.20572425434589e-5 -0.00014082959045095197 -6.554928789751471e-5 -0.00010976515135166172 -0.00010992968258119968 8.995368604178162e-5 -7.531550748876039e-5 -1.810715024088875e-5; 1.9930135435158802e-5 -6.80834149042242e-5 -2.8259519554833663e-5 8.599298385351596e-5 -0.00011855043917048123 6.44537846293162e-5 0.00012089393353935211 6.72079455890601e-5 -6.485945991325429e-5 6.745165378934103e-5 1.650723029715731e-5 1.3779351005820834e-5 
-9.043273380111884e-5 4.841394698278851e-5 -8.100811314388194e-5 -3.571658302998158e-5 9.957295483591642e-5 -0.00029601039054422655 -6.456173500364033e-5 0.00014680482641242805 5.004519747769304e-6 -1.0219629390023153e-5 5.327974727525208e-5 0.00012678860880814292 -6.218597477154013e-5 0.00010244502996475423 2.6607712838184203e-5 -9.77898681319626e-7 -0.00011677808137928136 -2.535044254354458e-5 0.00025653482376046744 8.155365007134607e-6; 8.444214954800652e-5 -7.973478450145823e-5 5.7915619711303084e-5 -4.611937763520427e-6 -7.53272713132607e-5 0.00013940768740433856 -5.829139144913777e-5 0.00019105345318387953 6.824982601977588e-5 -3.224646452287592e-5 -0.00013768439630218243 0.0001342059597868073 -0.0001625267129362047 -4.630708369920402e-6 -6.662935147466383e-5 -7.699184306835992e-5 0.00010632160165408311 8.298341463381181e-5 -5.383894922267028e-5 -0.00011916122702820599 -0.00019887349726338207 9.197994861490959e-6 0.00024060551850419213 -5.8555079429033586e-5 1.906759216802388e-5 -8.12113455118077e-5 4.027997044279145e-7 3.117537680958032e-5 2.7312252589065975e-5 -0.00011410034739816862 0.00012554669711101408 6.977763159960137e-5; -2.6970642551936044e-5 -5.16383052803974e-5 -7.709566952549214e-5 3.976933922380616e-5 0.0001597114741394335 0.0002278622113458839 7.824660761393266e-5 -5.336780040528824e-5 4.485862055439984e-5 -7.733736956147364e-5 2.124711036735719e-5 -7.97352487047142e-5 0.00011453554807808048 0.0001222746494982451 9.953238523444867e-5 -0.00014265066519732786 7.150195497665946e-5 0.0001889952535798844 0.00013138848490193883 -0.0002613524276006953 4.47424635290695e-5 -9.674480051621071e-5 -7.380936685371594e-5 0.00014843884755968626 8.728680305062444e-5 -0.0001428433907626126 -0.00015053439810293925 -0.00016254675674140273 3.996661226261196e-5 0.00011918153805304778 -1.3108355319173453e-5 -6.146868062792192e-5; 5.2313905571815484e-5 8.028437830088182e-5 3.2807561611295215e-5 -0.0001223508084234718 7.492201209114658e-5 -2.3703320520698114e-5 
1.1301417972024054e-5 0.00018243247856897704 0.00010793441421864257 4.107947413157716e-5 7.586056696763375e-5 6.495890508684844e-5 -9.162779363523832e-5 -7.474993665656474e-5 -4.261668484961188e-5 1.4564347539440994e-5 -0.000169584404401949 -2.1451767492304155e-6 1.0515327363369368e-6 -1.245642627455509e-5 5.8256958682784355e-5 3.5254484654274357e-7 -3.986114509590332e-5 -0.00011903262250543962 6.478580277924968e-5 3.820039953151211e-5 -0.000234138519344391 -0.00011721914100181326 3.909907124633382e-5 -6.113622502572135e-6 -7.520086412969877e-5 -6.978897047648199e-5], bias = [2.4470986297103634e-9, -4.486171633779073e-9, 4.289294165548167e-9, -2.7112635745545532e-9, -5.610677512287686e-10, 5.185226115116201e-9, 3.564219428894111e-9, 3.0583530843838876e-9, -4.535258792853174e-9, -6.563250623524105e-10, -1.3737585633185215e-9, -3.166155631984477e-9, -1.1664454453943717e-9, 1.5528843917541885e-9, 2.0419411645424304e-9, 1.958131018300607e-9, -4.143645851916311e-10, -1.673634059790535e-9, 1.1677296168339012e-10, -2.733006558291094e-10, 3.5974790207193423e-10, -1.6850054653907896e-9, -1.62053234271928e-9, -9.675202870289435e-10, 8.412822194338764e-10, -5.085088146391879e-9, -1.3485700228045206e-9, -1.699150519998719e-9, 1.0892600660604129e-9, 2.522050938135189e-10, 1.104950167126229e-9, -1.0118396074669257e-9]), layer_4 = (weight = [-0.0007759790916503085 -0.0005594652996663374 -0.0007634324859760256 -0.0007441510511481122 -0.0005998261328145645 -0.0008160159031524356 -0.0006806099626513552 -0.0006887165075462413 -0.0007547918798115265 -0.0008092198954450234 -0.00046415710428470927 -0.0006224276182072884 -0.0006245563041293927 -0.0007577269465779897 -0.0006417219727636501 -0.000815880619147511 -0.0005578698720115691 -0.000745059716966179 -0.0007282690446669878 -0.0007689373353621276 -0.0006948533000908666 -0.0007068979650375055 -0.0006325799349924657 -0.0007894840019396982 -0.000717578486240638 -0.0006291958219311486 -0.000844068971647476 -0.0007797031496611646 
-0.0004993153742124117 -0.0006621204061734572 -0.0005094688127114422 -0.0008117430526642916; 0.000197774347867792 0.00028596410076874024 0.00035073303261602234 0.0002348662850865889 0.0002823375195398884 3.099164077379006e-5 0.0002174407858761448 0.00041488866922802464 0.00024830249291574234 3.532181712075517e-5 0.00013127477211570605 0.0004593089834596943 0.00031329144752889393 6.363294673295771e-5 0.0003145027286874594 0.0004346150084707905 0.0003867470502939903 0.00019928663229676135 0.0004905015957982799 0.00015621010928036763 0.00017548317838212874 0.00019606415069184488 5.1401188142576223e-5 6.404353968026624e-5 0.00039498280451287824 -8.104125434128805e-6 0.0001597387681289586 0.0001911043514239188 0.0003377871027418083 0.00025737363615285384 0.00015186346634103991 0.0002459373568916567], bias = [-0.0006954116534923056, 0.0002296816529782467]))Visualizing the Results
Let us now plot the loss over time
# Plot the loss history collected in `losses` by `callback`. As with the other `begin`
# blocks, `fig` and `ax` are global bindings, and the block evaluates to `fig`.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
lines!(ax, losses; linewidth=4, alpha=0.75)
# Markers at x = 1:length(losses), overlaid on the line.
scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)
fig
endFinally let us visualize the results
# Re-simulate with the optimized parameters `res.u`, using the same fixed-step RK4 setup
# (`adaptive=false`, step `dt`, saved at `tsteps`).
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn =
    solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
# Final comparison of three waveforms on one axis: the data, the untrained network
# (`waveform_nn`) and the trained network (`waveform_nn_trained`), each drawn as a line
# plus circle markers. `fig`, `ax`, `l1`…`s3` are global bindings; the block evaluates
# to `fig`.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
# Data series.
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
)
# Prediction before training.
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
)
# Prediction after training.
l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
s3 = scatter!(
ax,
tsteps,
waveform_nn_trained;
marker=:circle,
alpha=0.5,
strokewidth=2,
markersize=12,
)
# Legend entries pair each line with its markers, in the same order as the labels.
axislegend(
ax,
[[l1, s1], [l2, s2], [l3, s3]],
["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
position=:lb,
)
fig
endAppendix
# Record the environment used to build this page. Julia/system info is always printed;
# GPU toolkit info is printed only when the packages are available and a device works.
using InteractiveUtils
InteractiveUtils.versioninfo()
# `@isdefined` only checks whether a name is bound in this module; it loads nothing.
# Each backend's info is printed only if its package name is already in scope and
# MLDataDevices reports that device type as functional.
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
endJulia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0This page was generated using Literate.jl.