Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.
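We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$, and recover the individual positions with the Newtonian center-of-mass formulas $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, where $M = m_1 + m_2$.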
"""
    one2two(path, m₁, m₂)

Split a relative 1-body `path` into the positions of two bodies of masses `m₁` and
`m₂` about their common center of mass: `r₁ = (m₂ / M) .* path` and
`r₂ = (-m₁ / M) .* path` with `M = m₁ + m₂`, so that `r₁ - r₂ ≈ path`.

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    share₁ = m₂ / total_mass
    share₂ = -m₁ / total_mass
    return share₁ .* path, share₂ .* path
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an orbit ODE solution (one column per time sample) into Cartesian coordinates.

`soln` must have either
- 2 rows `(χ, ϕ)`: the constant orbital parameters `(p, M, e)` are then taken from
  `model_params`, which must have length 3; or
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` are read per sample from the solution itself.

The radius is `r = p / (1 + e cos χ)`. The result is a `2 × N` matrix whose first row
is `x = r cos ϕ` and whose second row is `y = r sin ϕ`.
"""
function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = @view soln[1, :]
    ϕ = @view soln[2, :]
    if nrows == 2
        # Only (χ, ϕ) are integrated, so the constant orbital parameters must be given.
        # Check for `nothing` first so the default argument yields the assertion
        # message rather than a MethodError from `length(nothing)`.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = @view soln[3, :]
        e = @view soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled series `v` (spacing
`dt`), returning a vector of the same length as `v`.

Interior points use the second-order central difference `(v[i+1] - v[i-1]) / 2dt`.
The two endpoints use second-order one-sided stencils, so no sample is dropped.
Throws an `ArgumentError` if `v` has fewer than 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    length(v) ≥ 3 ||
        throw(ArgumentError("d_dt needs at least 3 samples, got $(length(v))"))
    # Integer coefficients with a single `/ 2` (instead of the `3 / 2`, `1 / 2` Float64
    # literals) keep the element type of `v`, so Float32 data stays Float32.
    a = (-3 * v[1] + 4 * v[2] - v[3]) / 2
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = (3 * v[end] - 4 * v[end - 1] + v[end - 2]) / 2
    return [a; b; c] / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled series `v` (spacing
`dt`), returning a vector of the same length as `v`.

Interior points use the central stencil `(v[i-1] - 2v[i] + v[i+1]) / dt²`. The two
endpoints use the second-order one-sided stencil with coefficients `(2, -5, 4, -1) / dt²`.
Throws an `ArgumentError` if `v` has fewer than 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    # The endpoint stencils reach four samples deep.
    length(v) ≥ 4 ||
        throw(ArgumentError("d2_dt2 needs at least 4 samples, got $(length(v))"))
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free quadrupole moment, `mass * (I_ij - δ_ij tr(I) / 3)`,
at every time sample of a planar `orbit` (row 1 = `x`, row 2 = `y`, one column per
sample), where `I_ij` are the products of coordinates and `tr(I) = x² + y²`.

`component` is an index pair `(i, j)`:
- `(1, 1)` gives `x² - (x² + y²) / 3`;
- `(2, 2)` gives `y² - (x² + y²) / 3`;
- `(1, 2)` or `(2, 1)` gives `x * y`.

Any other pair throws an `ArgumentError` instead of silently returning `x * y`.
"""
function orbit2tensor(orbit, component, mass=1)
    i, j = component[1], component[2]
    x = orbit[1, :]
    y = orbit[2, :]
    if (i, j) == (1, 1) || (i, j) == (2, 2)
        Ixx = x .^ 2
        Iyy = y .^ 2
        trace = Ixx .+ Iyy
        tmp = (i == 1 ? Ixx : Iyy) .- trace ./ 3
    elseif (i, j) == (1, 2) || (i, j) == (2, 1)
        # Off-diagonal entries carry no trace correction (δ_ij = 0).
        tmp = x .* y
    else
        throw(
            ArgumentError(
                "component must be (1, 1), (2, 2), (1, 2) or (2, 1); got $(component)"
            ),
        )
    end
    return mass .* tmp
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (finite differences with sample spacing `dt`)
of one trace-free quadrupole component of `orbit`, as selected by `component`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the three independent second-derivative quadrupole terms of `orbit` as the
tuple `(h11, h12, h22)`; note the order places the off-diagonal term in the middle.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    term(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return term((1, 1)), term((1, 2)), term((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain of a single body (unit mass) moving on `orbit`, returned
as `(√(π/5) (h11 - h22), -√(π/5) · 2 h12)`. The constants are built in the type `T`
of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    scale = √(T(π) / 5)
    plus = scale * (h11 - h22)
    cross = -scale * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole terms `(h11, h12, h22)` of two bodies, each weighted by its own
mass, for samples spaced by `dt`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # Component order follows h_22_quadrupole: (h11, h12, h22).
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain radiated by two bodies on `orbit1` and `orbit2` with
masses `mass1` and `mass2`, which must sum to 1 (checked up to floating-point rounding).

The mass-weighted quadrupole terms of both bodies are summed. The result is returned
as `(√(π/5) (h11 - h22), -√(π/5) · 2 h12)`, with constants built in the type `T` of `dt`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    total_mass = mass1 + mass2
    # A fixed 1e-12 tolerance is below Float32 resolution (eps ≈ 1.2e-7), so Float32
    # masses that sum to 1 only up to rounding would be rejected. Allow a few ulps of
    # the masses' float type; for Float64 this keeps the original 1e-12 bound.
    tol = max(1.0e-12, 8 * eps(float(typeof(total_mass))))
    @assert abs(total_mass - 1.0) < tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an orbit ODE solution sampled every `dt` into the (2,2)-mode strain tuple.

`soln` is first converted to Cartesian coordinates with `soln2orbit(soln, model_params)`.
`mass_ratio` must lie in `[0, 1]`:
- `0` is the test-particle case, in which the single orbit radiates on its own;
- otherwise the orbit is split into two bodies with masses `m₂ = 1 / (1 + mass_ratio)`
  and `m₁ = mass_ratio * m₂`, so that `m₁ + m₂ = 1`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    iszero(mass_ratio) && return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using
$$u[1] = \chi, \qquad u[2] = \phi,$$
where $p$, $M$, and $e$ are constants.
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = (χ, ϕ)` and the
constant parameters `(p, M, e)`. Returns `[χ̇, ϕ̇]`. Neither `ϕ` nor `t` enters the
rates; both are accepted only to match the ODE interface.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)
    shared = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    scale = sqrt((p - 2)^2 - 4 * e^2)
    dχ = shared * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * scale)
    dϕ = shared / (M * (p^(3 / 2)) * scale)
    return [dχ, dϕ]
end
# Problem setup. These globals are read inside `loss` during training, so they are
# declared `const` (as the tutorial recommends for globals) to keep their types known
# to the compiler instead of being treated as untyped globals.
const mass_ratio = 0.0                     # test particle
const u0 = Float64[π, 0.0]                 # initial conditions (χ, ϕ)
const datasize = 250
const tspan = (0.0f0, 6.0f4)               # time span for the GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each sample
const dt_data = tsteps[2] - tsteps[1]      # spacing of the sampled waveform
const dt = 100.0                           # fixed RK4 integration step
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Simulate the ground-truth relativistic orbit with fixed-step RK4, sampled at `tsteps`,
# and keep only the first element of the strain tuple returned by `compute_waveform`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); dt, adaptive=false, saveat=tsteps))
waveform = compute_waveform(dt_data, soln, mass_ratio, ode_model_params)[1]

# Plot the true waveform as a line overlaid with markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s = scatter!(ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (the angle χ, i.e. the first component of the ODE state) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but if you do use them, make sure to mark them as const.
We will deviate from the standard neural network initialization and use WeightInitializers.jl: the weights are drawn from a truncated normal distribution with a very small standard deviation (1e-4), and the biases start at zero.
# Network: apply `cos` to the scalar input, then two 32-unit `cos` layers, then a linear
# layer with 2 outputs. All weights share one tiny truncated-normal initializer and
# all biases start at zero.
const nn = let init_weight = truncated_normal(; std=1.0e-4)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight, init_bias=zeros32),
        Dense(32 => 2; init_weight, init_bias=zeros32),
    )
end
# Initialize the network's parameters and states with the default RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
# (The displayed `(ps, st)` — randomly initialized Float32 weights, zero biases, and
# empty states for layer_1 … layer_4 — is omitted here.)

# Similar to most deep learning frameworks, Lux defaults to using Float32. However, in
# this case we need Float64.
const params = ps |> f64 |> ComponentArray
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Displayed summary of `nn_model` (a StatefulLuxLayer wrapping the Chain):
#   layer_1 = WrappedFunction(Base.Fix1(fast_activation, cos))
#   layer_2 = Dense(1 => 32, cos)     # 64 parameters
#   layer_3 = Dense(32 => 32, cos)    # 1_056 parameters
#   layer_4 = Dense(32 => 2)          # 66 parameters
#   Total: 1_186 parameters, plus 0 states.

# Now we define a system of ODEs that describes the motion of a point-like particle
# with Newtonian physics, using
$$u[1] = \chi, \qquad u[2] = \phi,$$
where $p$, $M$, and $e$ are constants.
# Neural ODE right-hand side. It keeps the Newtonian rate (1 + e cos χ)² / (M p^{3/2})
# and multiplies it, for χ̇ and ϕ̇ respectively, by the two components of `1 + NN(χ)`,
# where NN is `nn_model` evaluated with the trainable parameters `nn_params`.
# `p`, `M`, `e` come from the global `ode_model_params`. `t` is unused but belongs to
# the `f(u, p, t)` signature that `ODEProblem` calls.
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # `nn_model` is a StatefulLuxLayer that already carries the (empty) state `st`, so
    # only the parameters are passed. In general, `st` is where a network keeps state.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Simulate the untrained neural ODE with the same fixed-step RK4 settings and sample
# times used for the true model, then convert it to a waveform for comparison.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(
    solve(prob_nn, RK4(); u0, p=params, dt, adaptive=false, saveat=tsteps)
)
waveform_nn = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)[1]

# Overlay the true waveform and the untrained neural-network waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; alpha=0.5, marker=:circle, markersize=12, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
# Mean-squared-error loss object (Lux's `MSELoss`), created once and reused by `loss`.
const mseloss = MSELoss()

# Training objective: re-simulate the neural ODE with parameters `θ` (same fixed-step
# RK4 settings and sample times as the data) and compare the first element of the
# resulting strain tuple against the true `waveform`.
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted = compute_waveform(
        dt_data, Array(trajectory), mass_ratio, ode_model_params
    )[1]
    return mseloss(predicted, waveform)
end
# Warmup the loss function
loss(params)  # displayed result: 0.0006925975213593596

# Now let us define a callback function to store the loss over time
const losses = Float64[]

# Callback passed to `Optimization.solve`: record the current loss, print a progress
# line using the iteration counter of the optimizer state, and return `false` so the
# optimization is never stopped early.
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.
# Gradients come from Zygote. The objective ignores its second argument (the problem's
# hyperparameters) and only evaluates `loss` at the current parameters.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-5.8689511206429835e-5; 3.6285262467517375e-5; 3.665591430030331e-5; -4.206503581370491e-5; -7.482390174121515e-6; 0.00019620954117237268; 0.00010310189827556269; 3.760481922654195e-5; 0.00014143713633542953; -0.00016752821102248503; -7.314752292582116e-5; -9.97510142040912e-5; -0.00012441088620115525; -1.4491298315981026e-5; 0.00014400461805055148; 8.813157910476e-5; -0.00018816076044469136; -5.782061998613464e-5; 5.914003122592922e-5; 2.7802338081504905e-5; -1.6048126781242114e-5; -0.0001694761449469799; 8.514624823864108e-6; -2.9371743948992383e-5; 0.000129690844914847; 0.0001525258849141709; 4.7288969653905426e-5; -2.1375904907477697e-5; -0.00015740848903058124; 3.074422420463922e-5; -0.00018117451691046036; -0.00010200453107240679;;], bias = [-4.128433468706353e-17, 6.405205017231377e-17, 5.908817645391201e-17, -1.8962535909762085e-17, 3.586103737078517e-18, 4.01134399917652e-16, 2.478518064966306e-16, 3.120512449060804e-18, -1.5194274933098374e-16, -2.78354417502918e-16, -4.995007324628705e-17, -1.6195881365207218e-16, 9.244079720141831e-17, -8.53337364163451e-18, -2.2690681382999782e-17, 5.911606172422847e-17, 1.1682985601843932e-16, -2.900780014897287e-17, 5.783685107403283e-17, 2.708960329550871e-17, -8.880719437412873e-19, -1.548318180229551e-16, 1.1796193712597408e-17, -2.4529914455311634e-17, 3.227547010700931e-17, 4.723359695397834e-17, -1.673818624770948e-17, 6.19186583063849e-18, -2.342510858746274e-16, 4.9041567264429875e-17, -2.601497127929737e-16, -5.391055840265236e-17]), layer_3 = (weight = [0.0001232907700098953 -0.00011305675141253455 9.939206504290044e-5 -5.749383560820452e-5 -1.936797413383638e-6 -4.1958159978714575e-5 0.00011226303033557321 9.318068099071285e-6 -8.5886623107101e-5 -0.0001417066494256476 0.00012604196965488462 -4.233589403763346e-6 0.00011804430980494542 -8.549882911817065e-5 0.00018671174186540813 9.909857474061402e-5 -0.00013797657154707906 
-0.00012486347693690142 1.760084453356865e-5 5.79971412874673e-5 -9.400599517599576e-5 5.324820378820999e-5 7.869822295409739e-5 0.00012624761732089407 -0.00010788757642937592 4.803416525039741e-5 3.733416744884112e-5 -0.00014310312397052439 -9.831069088863809e-6 0.00011458380615598594 2.5655504235438228e-5 9.725382844711265e-5; -0.00014139620510635064 0.00012029966852621175 8.293455572326666e-5 5.238196932820582e-5 -0.00010083748533164093 5.161502155001421e-5 2.6269765259861375e-5 8.434292645885606e-5 -0.00024429805092753774 0.00011378041050389106 -5.424144150113873e-5 6.856333162628646e-5 -0.000130086616661937 9.045665898353844e-5 9.807262209702317e-5 1.8991572107702312e-5 -6.81973357476172e-5 0.0001370219509791853 -0.00015590769511421312 -8.328041770967628e-5 5.213943983305295e-5 -4.410347504549062e-5 7.83472245185921e-5 2.1597540000737858e-5 4.732223929946526e-5 -0.00010159356646712953 0.00012590366201235871 6.177318970200646e-5 -1.2564502538083238e-5 -3.446683662135137e-5 0.00010576710734508408 9.634148260571574e-5; -6.388498252942429e-5 -8.517357594552533e-5 -0.00023816681016921153 8.87640259737273e-6 -8.790239474489366e-5 3.269930209044999e-5 6.67861739826794e-5 8.211716237259009e-5 -7.077516159495534e-5 0.00017295544136560804 8.639252961859729e-5 0.00013912239853073273 -8.795069255154526e-5 -5.395284563871122e-5 7.66041183566092e-5 -6.701619088869054e-5 -6.640989989263131e-5 -9.119113486246401e-5 0.00016900449451771356 -2.9930085594393847e-5 0.0003001236681199165 9.289458727709266e-5 -5.1238502357015523e-5 6.265700982897485e-5 3.518762502483074e-5 9.449037212891125e-5 7.307764798162787e-6 -1.7787034386280347e-5 -7.03912622282745e-6 2.3804725850489293e-5 0.000134299690097226 -0.00011979670046582478; -4.104870961309667e-5 9.290294335660339e-5 -7.509453055245059e-5 3.5091680251935354e-5 -0.0002006412810485917 -7.113043488801132e-6 7.830016190855267e-5 4.117822616099193e-5 9.7797843841539e-5 0.0001485950656162374 8.577430205623298e-5 -0.00012120816747920132 
-7.125719050673498e-5 6.578650434239621e-5 3.43655469579911e-5 -0.00015495410254952896 -7.58756046023446e-5 5.418109728925608e-5 8.386840861494393e-5 -0.00015058292549365363 -4.534989195672013e-5 -6.551385515632282e-5 6.630942014017e-5 3.325221630366349e-5 2.1828445598610412e-6 0.00012974078933911402 6.371391143620727e-5 -0.00017302261906006525 0.00019228139942714903 0.0001378454639736453 8.852798902251986e-8 1.5893549290896502e-6; 3.215440903711446e-5 -1.2124574258840581e-5 -0.0001571834376404047 -0.00014344455246776162 3.935663937279069e-5 -0.00013441385215573655 2.5594036929572144e-6 0.00013734498372338868 0.0001849583119297562 -3.905901155977113e-5 -1.1969656388305602e-5 -0.00010370684447427385 0.00018230743318834675 -0.00020344992788471223 -0.00011232760278236304 4.243513341914587e-5 -0.00010749359478668267 -5.6155329165837464e-5 5.3298840267028114e-5 -8.851137807089731e-5 -4.982155713098992e-5 -5.783681024641282e-5 1.717284305465921e-5 -0.00010755042001567671 0.00014677271849048763 5.9915145012010055e-5 -0.00012150986129147715 8.909233894934955e-5 -0.000114984869814551 7.865735698994366e-5 -3.0949196461189274e-5 -3.9328614893221044e-5; -0.000173117336071285 0.0001202142120928476 -6.973303605262466e-5 -5.294283253532171e-5 -0.00010346082488466353 5.4153028049789225e-5 0.0001461633110193524 3.24893517427119e-5 7.543449830338403e-5 3.878066492690885e-5 4.5797740331196957e-5 -0.00010485705932295078 -1.2581158425729136e-5 -0.00011007969804266763 -0.0001707267183337355 6.62858183386719e-5 1.1793782523475382e-5 -9.056454653544592e-5 5.718641677973734e-6 4.732978415633267e-5 8.622430688775781e-5 -8.50617106880042e-5 -0.00010103113525472879 -0.000186692250263135 0.00016195644640903907 -0.00022005398566760022 0.0001900106472203482 -0.00012272628007848105 -0.0001004513141924503 0.00017310282057225423 0.0002473446693511706 6.271026724791034e-5; 0.00013467644953462935 -0.00010275696490597245 8.874190302859725e-5 8.740655380362374e-5 -4.356213116724294e-5 
-5.182271478672357e-6 -4.0325748840667404e-5 0.00010521424178560929 -0.0001328778578218495 -2.589050177365675e-5 0.00016001162498552036 -3.5582453846553356e-5 -8.863168272827264e-5 -0.00010825007279807954 -6.457335780313747e-5 6.013540246468621e-5 5.908856314900572e-5 -7.215230812024618e-5 4.9986984095915074e-5 3.8646072653967155e-5 1.5903673784953802e-5 1.1365303408555662e-5 1.0496149344283441e-5 -3.269227977315225e-5 -6.632880626883817e-5 3.467040415589422e-6 -1.2889939852436523e-5 -4.386438899845141e-5 0.00012264547899659163 -9.095755899496638e-5 3.7456206201485284e-5 8.131758136270188e-5; 9.006218431561352e-5 -0.00014324687537926227 -6.490510452029277e-5 3.259267294008322e-5 0.00010484061465642387 5.3287864719037944e-5 -4.716171417574032e-5 2.0703548201741327e-8 3.5401949278756664e-5 1.0904535773017803e-5 -8.338325744450855e-5 0.00010861463204303164 7.451274985026183e-5 -0.00019335977151247691 1.1343858093759476e-5 -2.328215952918605e-5 -6.85892492784338e-5 4.9361146809971905e-5 -7.610955537718622e-5 4.6427378836525594e-5 3.159622326886261e-5 -0.00014145055785963205 -2.1070566728739333e-5 -1.104604700420049e-5 1.8764837366623244e-5 -4.5927846369585984e-5 -4.86032450881595e-5 0.00014393384742827144 -7.77087089973692e-5 -0.00010974432836708607 -2.3315670770958264e-5 6.888464384057914e-5; 8.686746093994394e-5 -0.00016121672581098545 -9.982045650863591e-5 7.784499192754289e-6 -0.00019406570993046754 -8.975554430540954e-5 -8.851478252941496e-5 4.713817850538025e-5 2.833122869455318e-5 0.0001379189628813406 6.779101329706958e-5 -2.231858763585073e-5 9.292604898186552e-5 1.5834125750958774e-5 -5.1015282789337905e-5 0.00014746513568647505 4.250266590937734e-5 5.02305987346025e-5 0.00010485051405234648 0.00016828927562423488 -8.411176588302027e-5 -4.240387044879717e-5 0.00012649361395128525 -5.238177311289617e-5 -6.172406630972275e-5 3.6915890976926306e-5 5.507893892032485e-6 0.00015948031739648877 5.791980473842392e-6 -5.9785525646615134e-5 -0.00011471386153167304 
-5.038605797486129e-5; -0.00016913033693433116 -1.059691027782611e-5 5.098842296413857e-5 2.080934421522305e-5 3.917455972535369e-5 0.00010633323466301087 -0.00018951094904784714 -4.504597388310581e-5 0.00019766467078718303 8.43847700989674e-5 -5.11925833868247e-6 2.3007649298076383e-5 1.683116709593817e-5 5.447327560526145e-5 -3.302339252420783e-5 4.784620971898065e-5 -0.0001543294820921657 9.85638725351512e-5 0.0002074508337782665 0.0001103526118922959 -7.947939236052511e-5 -6.36116129216872e-5 0.00013332986775915903 -4.496811749865501e-5 6.457631018840786e-5 -8.782573792721329e-5 -7.439090047134344e-5 -7.59402570926652e-6 -0.00012380532650147811 -4.0796770980676816e-5 0.00012547245926818233 7.877338423989122e-5; 8.000457244637565e-5 0.00010116993296836479 -1.04038993594844e-5 3.3944046936412497e-6 2.8359359639611762e-5 -6.688522473262557e-5 -9.338631418355041e-5 0.00011384102042942292 -0.00020551659174084167 5.788378410221267e-5 9.38764694001648e-5 -0.00011414766260709898 7.631570559173968e-5 0.00014811532098977714 2.6521329055005614e-5 -5.971774073255701e-5 6.736260336771432e-5 1.2539238632632375e-5 0.0001135302278999347 -0.00018190056500312936 -4.885081063545412e-5 -0.00011408533675416132 -0.00017662428792388299 1.9190319381188715e-6 1.8584285207952064e-5 -0.00013228650852392613 -4.0956323864487426e-5 3.537213495295407e-5 -9.615243681759639e-5 -2.272565536255999e-5 -0.00010598995351750777 -0.00012691028889090072; 1.084937849877452e-5 3.3644809094353925e-5 -4.169620041143767e-5 1.829085778221796e-5 0.00010213077877192158 0.0002121518362523964 -7.566083257775578e-5 -3.4720365134902904e-5 1.4599698091426766e-5 0.00015542003058139148 5.440310889166103e-5 -4.525498769585789e-5 1.4622877473396825e-5 -9.577882627714095e-5 5.883531667999949e-5 3.66699303536232e-5 -3.5196572922777437e-5 -6.319120183111816e-5 2.1296248463427676e-5 -1.1220962902783916e-5 2.425950300263164e-6 -0.0001436849244659776 -1.084230063001034e-6 1.49453979370682e-5 2.0698606761964878e-5 
-9.162870070629405e-6 4.2214028156691435e-5 5.767085787151336e-5 -0.00020139133010748812 -9.583298850562105e-5 0.00013028140784948907 -2.869868984494644e-5; -6.383813211472624e-5 -0.00010591736695973103 1.10539462587969e-5 -0.0001459353812202988 1.2512577490595808e-5 0.00015908310837799445 0.00011703842705570176 8.231289434077733e-6 -1.1925876202116029e-5 -5.93259870180595e-5 -1.1506536389462956e-5 -7.236541983404689e-6 0.00021539836547559095 4.655064134963897e-5 -9.789947267167347e-5 1.3507016269029524e-5 0.00011245912850823043 1.0447656172216743e-5 4.221504008852024e-5 2.805104617833264e-5 -5.465005920670307e-6 5.411721253363318e-6 -0.00018144071560131274 2.0219812944886376e-5 -1.0519680977852967e-5 0.00010999075988761981 -2.1440657797613205e-5 0.00011716187822854254 3.2331171900619545e-5 -4.481240517238053e-5 6.320368942793274e-5 -0.0002840351231094635; 0.0001208032026257942 5.3310706752781003e-5 -3.24082630325945e-5 0.00016391608183734168 2.503683568614928e-5 4.3555281044672085e-5 -0.000241645533642762 5.270259676730286e-5 -8.801555854037909e-5 3.896695667177513e-5 -9.238523674899406e-5 -7.222031312299255e-5 -4.0014247500490395e-6 2.1800169597225584e-5 0.00012779412460045436 -0.0001324989080243137 4.873235407982622e-5 -0.00015077892207078467 -2.6917910330294418e-6 5.566375872116515e-5 -8.401168637679949e-5 0.0001243697097012599 -2.038020833457426e-5 -0.0001115170544293152 2.0808929511669214e-5 -0.00013148069596385855 3.00804258219068e-5 -0.00013926258001480297 -0.00017477028635798185 -5.299426815359861e-5 -1.739062467125761e-5 6.112197112056166e-5; -0.00018713428885894177 -0.0001229308314177628 -1.250333008102959e-6 -0.00014574442189405174 -5.618719165797894e-6 -0.00014663702182221073 -0.0001873953356662168 -0.00015730099999564236 0.00019787002515296768 -7.791831608844795e-5 -0.00012628076961482046 5.298142190805026e-5 5.158305561418542e-5 7.725967720398102e-5 -3.5553401681665235e-5 -2.0689731678426002e-5 -1.1635520928363693e-5 -6.536771120954593e-5 
9.414908567124105e-6 -8.904279689522985e-5 1.5610924896461764e-5 0.0001491728994042445 -9.70642200227327e-5 0.00014415087338784193 3.6799888624968224e-5 2.4123262341182302e-5 3.6543458412779327e-5 2.712836198041037e-6 4.0385975302362906e-5 7.374470573631538e-5 -3.162086300650979e-5 -0.0001339360956590028; -0.00012459659722357396 -0.00010894051177150934 -5.03568766622654e-5 1.7481558253491446e-5 5.570887203371708e-5 -7.160835178059479e-5 7.654785492500961e-5 -0.000104079866495015 -5.293968518220977e-5 -3.5078315184945944e-5 3.9522158268808e-5 -0.00018916060746836745 6.604669813110285e-5 5.745069626878573e-5 0.00016770929557547755 2.632512275464075e-5 0.0001526102428864409 3.365988283927212e-5 8.1421967999047e-5 1.2524285531396552e-5 -2.1033070359992812e-5 -0.0001463597431471857 0.00012883378128331709 6.988515139833656e-5 -0.00017239627725633986 7.353894813728398e-5 3.065111429473964e-5 -9.31208118605744e-5 -2.981976233844751e-5 0.00010028212324992387 -3.873756161614339e-5 -4.018558995871136e-5; 0.00019081993999811808 -0.0001145806028585432 -0.00010066799692190418 -0.00014332179231182806 -3.425215360089461e-5 -0.0001724340337755378 0.0001049616901941181 -3.8123617887824504e-5 0.00023567198408223913 -0.00017705762835296937 4.366638076545232e-6 -0.00010087277874894927 6.0802844781809466e-5 -0.000294107428235796 3.0235217163374615e-5 -0.0001382205311170284 8.547874935127246e-5 -0.00022978167649903352 -7.190431748363203e-5 -4.573139554900322e-5 0.0001204578573737735 -8.457491745551204e-5 7.04782131283053e-5 6.85168768899961e-5 -0.00018927330635100603 0.00017363651462752827 6.625862337336772e-5 1.2233319598154625e-6 9.333430934444572e-6 8.354928908324607e-5 -8.664604006550811e-7 5.7348245572443233e-5; 4.868047995395558e-5 -2.874866572247668e-5 -2.2539330752723294e-5 -9.921216246303506e-5 -3.966570352068436e-5 -6.834085638267861e-5 3.238734075702038e-5 5.5597349774793915e-6 -8.43228610801058e-5 4.0407530609700005e-5 2.8881518806056004e-5 8.324998066663079e-5 
9.894136178907988e-5 -1.5916768684208304e-5 3.0780936645062895e-5 -9.637334573645718e-5 -8.366913811635043e-5 -3.2377773318248106e-5 0.00016774379196712346 3.2126286673785036e-5 -5.704480485738528e-5 0.00022912232976075715 0.0001085771016324471 4.183043870861344e-5 9.886123512010046e-6 -0.0002914551935145192 -8.788807685876994e-6 6.504238422630072e-5 0.0001972987172440128 6.985475352402475e-5 0.00011527470792778122 -0.00015511349760791962; 0.0001352795245849518 2.4967617974802176e-5 -0.00015927600699691167 -5.799809306957832e-5 -0.0001467724193560896 0.00013748828703788495 -2.497707620826738e-5 3.261693771597861e-5 -7.00664001841129e-5 0.0002006729939977597 -5.945785389379718e-5 -5.374955955525109e-5 3.633935764915989e-5 -7.4610040599834855e-6 1.4294566682257296e-5 0.0001540218675301395 2.1794649256807325e-5 1.520663070656498e-5 6.450715798670328e-5 -7.40494485570124e-5 -6.474286436819778e-6 -5.751543150730854e-5 -0.00011045748189538769 6.99592003740364e-5 -8.033654532376944e-5 0.00013461556434882452 9.587510638190353e-5 -0.00014326667380267812 8.170683057065683e-5 -1.0056766512967346e-5 9.97569102870291e-6 -6.386625111084616e-5; 0.00013934343036586246 -0.00015093199004562595 -5.18560323920361e-6 0.0002470917770019354 -0.00015344407261958282 6.038731038773683e-5 0.00012646498538875295 5.8135838063629024e-5 1.2695607384965885e-5 0.00015176621122140466 0.00010786023807801007 4.459793303017826e-5 -0.00014385380751132088 -7.099418695228126e-5 -2.9998447697427487e-5 3.258331156890378e-5 -4.930717349054813e-6 -0.00017067918639308006 2.983639018699354e-5 0.00014575494600782186 -6.81573565621297e-5 -7.34575351621585e-5 7.493901097538008e-5 -0.0002528477107794396 0.00021409529466968815 3.004428430490081e-5 -9.697100000133826e-5 -2.7978596388981347e-5 1.1409285381382952e-5 0.00018513273420600222 0.00021104012006748174 -8.581248954082548e-5; -2.912909348808437e-5 -8.792344214882629e-5 -6.259835845114588e-5 -9.14657348373335e-5 3.721306075319611e-5 4.205053210472267e-5 
7.986724904484066e-5 -0.00015018547764768743 0.00012307779343545268 6.874077462571277e-5 -7.776217443890447e-5 -0.0001767369749329343 1.2664096178276455e-5 4.724270638987398e-5 6.210785886950688e-5 -8.448706095976418e-7 -9.182389385086245e-5 8.272295326903711e-5 -4.126931071070694e-5 0.00012890720336442837 -2.31304684318511e-5 -0.0001159440062080637 -0.00012025442174192555 -0.00015957766208587185 0.00010722455863587115 -8.164464724969001e-5 7.104702763470041e-5 -9.540287191769113e-5 9.629242504460985e-5 -0.00014987783560784494 -7.158317657439953e-5 -4.683498172809439e-5; 1.169335131038821e-5 0.00012680783377931463 4.237380902254037e-5 -6.246670115302325e-5 0.00011364085342706261 -2.0904981526271123e-5 -4.271226158977002e-5 2.5086375656831183e-5 5.592876489965895e-6 0.00013802667648035648 -6.956018538912283e-6 8.011949806444866e-5 -1.7729904599550542e-5 -8.566243219916637e-5 -0.00016573542986094973 -6.074011641118049e-5 -3.7347665299035616e-5 0.00012275652237207538 -9.66036243576659e-5 0.00019627713457652407 5.7321223478941924e-5 3.6295310322591616e-5 -9.80811843560932e-6 8.417395996133835e-5 0.00012764759570330555 4.720526679505451e-5 -8.31475411055706e-5 -1.2178999364157856e-6 -5.775716660225051e-5 -0.00012926053047209016 0.00010785125753811874 -5.597283622279268e-5; 6.484755378133945e-5 0.0001536154226554677 -0.00024463496136794535 -0.00011045148911270036 -2.738697056391985e-5 8.991148880834223e-5 -2.826677208266606e-5 1.3821427651611469e-5 -0.00027805835323723127 -3.388089006382826e-5 8.614013442221481e-5 -2.8123030083029866e-5 -0.00023033226223977523 4.377582028192696e-5 -2.1726502913754273e-5 -7.536957908504183e-5 -0.00013361295767442882 -2.478086193684982e-6 6.717161263997918e-6 7.404709996572938e-5 3.340292406884252e-5 -6.915077445644717e-5 4.2602088265458496e-5 0.00023795005783123214 9.412888122407523e-5 2.3925163455186407e-5 -2.6394244364008207e-5 -5.5978013355878617e-5 7.413435324944603e-5 -6.918544439448727e-5 6.052158952608996e-5 9.55283898432886e-5; 
-1.9062960358652186e-5 6.767249708521315e-5 0.00022506427324135922 6.779981906750367e-5 -0.00018374556254836577 0.00013090814302433798 3.284502001936544e-5 -1.4845748022863236e-5 8.74699702997771e-5 8.932629991779776e-5 -6.465166968534625e-5 -2.3197526910891683e-5 8.634201497063898e-5 -5.741512957153691e-5 -0.00014942260038348996 -3.1767910126447286e-5 4.2624332696344714e-5 6.710574364281366e-5 -5.570943410400199e-5 -4.071798555211952e-5 8.144816949955337e-5 -7.472465092557752e-5 -6.098566767987425e-6 0.0002449843317898097 -0.00016948101854193117 2.141126917724773e-5 -0.00010068844728645747 7.60746710746657e-5 -3.549006084920119e-5 -5.4820988747606515e-5 -0.00020959097180291408 8.987089806926437e-5; -8.813260829641908e-5 -0.00020907313920192125 -0.0001956953161499247 0.0001256275758276914 -4.0410570007356014e-5 -0.00014187742963995542 -6.285385062086109e-5 9.814148906285252e-5 -4.687447982007356e-5 9.96428424685875e-6 -3.7145844187887865e-5 -3.1133636737771004e-5 -4.028603107883761e-5 -3.171914396093039e-5 -8.240272619161398e-5 -2.431991525385059e-5 1.9438990801921996e-6 -1.6989891793313688e-6 -0.0001584032685133391 -2.285150877300406e-5 7.141854111237644e-5 3.1178571060874065e-6 3.946956796295129e-6 -4.223664983159609e-5 3.32954117692445e-5 7.5809224580381864e-6 -9.235904697690468e-6 0.00013910074302826293 -6.943336084815851e-5 -7.133383370104916e-5 0.00012213700792189372 -5.0556701427310776e-6; 0.00026266778816822395 -0.0002800616354459958 -0.00014826959583555033 2.2291225225821697e-5 -4.3454205903650474e-5 8.509508127934427e-5 2.4872411207224143e-5 -5.806517280401343e-5 -5.6929814187993836e-5 2.329731917887119e-5 2.9688671323280167e-5 -4.38464309507538e-5 1.540569845175419e-6 0.00011820046286917558 -0.00013532734071016055 -4.5098881552651055e-5 9.60810787848289e-5 0.00013246320243812573 -1.1666564811852817e-5 8.557124177347228e-5 -0.00012562259568479577 6.382451677543951e-6 5.5941337753359954e-5 -0.00013150701185335112 2.5612923612399853e-5 2.7738856563860562e-5 
1.1192202994757172e-5 3.682360996496209e-5 -7.599560255296086e-5 -6.249804634247608e-5 -0.0001444588857946837 -4.220997349587475e-5; 9.981440245117025e-5 0.00015448577334116297 7.702281447070733e-5 -0.0001529317304497672 0.00016595172211471168 0.00017107606264703774 0.00017680286702177045 0.00011826823184908446 5.930864800990528e-5 1.9483923201332645e-5 -0.00011225257174867348 6.854512425498764e-6 2.1085024959181795e-5 9.736409275735872e-5 1.8176557309347424e-5 -0.00021875512164608855 -7.858977904601913e-5 -7.88872056413538e-5 -0.00024902773283013184 -3.137474019908001e-7 1.3105309439722004e-5 -1.0005071895139347e-5 -9.565668723136966e-5 5.1321840250763254e-5 -3.9429876674542597e-5 1.992693079461131e-5 7.30378098030606e-6 9.934355340608527e-5 5.377856370120012e-5 0.00017120186395420798 -5.304946524615215e-5 -5.7890116191110176e-5; 3.36730569456978e-5 -3.980519837059989e-6 4.1486664171824935e-5 -6.12946990672215e-5 -8.089948309629077e-5 6.601893769996899e-5 0.00010419147807129176 2.155075676870133e-5 -7.848987510938303e-5 0.00022588269862766725 -3.9019220226185216e-5 7.079537831901133e-5 2.7068093599655423e-5 -0.00011591745928439174 4.026516913158669e-5 0.00011546652026896056 -4.665373703161345e-5 -5.486983568102169e-5 0.00011882418553301073 0.0002459503427003801 -5.265752619874528e-5 3.3902340559999724e-5 -8.496669247091602e-5 -0.00010232952840437083 1.5726771066101158e-5 2.206363174495452e-5 3.6946293262725816e-6 3.261279712620315e-5 0.00013023306169502048 -0.0001226978368735584 6.506830473384881e-5 -9.758465088948144e-5; -6.970603380451752e-5 -4.175218663860145e-5 -0.00012897669029754614 6.183751730148688e-5 0.00011318414663563338 -0.00015233632684083178 -0.00015150020289566497 9.96509164049567e-5 -6.081405326239999e-5 2.0658940507755776e-5 -1.9159225025043355e-5 1.2853948924923682e-5 -7.512748079007195e-5 -0.00027024725085922645 7.785961980135011e-5 -6.605139305448549e-5 3.488234624984876e-5 0.00019490869208838003 -0.00013964020280965442 0.00010156001854355367 
6.324767064669042e-5 -2.948060200123042e-5 -4.26001340149171e-5 -0.0001320343678291615 -5.211597335273918e-5 -6.0146805428254015e-6 -4.4734852685192096e-5 4.298307003935692e-5 1.2711870391109993e-5 -0.0001478164437632735 -1.0791332619360449e-5 9.516873733972974e-5; -0.00021286957004252064 -0.00017868524397402568 8.992483393609962e-5 5.6192010141392615e-5 -0.00021495369534153244 -2.9234751729173967e-5 -0.00018270045207956836 1.2865151382873135e-5 0.00014750351741767934 3.2745765916947515e-6 3.459780831560601e-5 -2.51718097037323e-5 4.7067897447587445e-5 -1.3796086126917432e-6 4.750572819702883e-5 -0.00010123260241145726 -4.8308323549805585e-5 -7.5416922719732e-5 2.244090550318191e-5 -3.874367345036585e-5 -7.528918600786204e-5 -9.121355056903433e-5 -9.788885605433709e-5 5.2028965491050794e-5 -3.2635325273141252e-6 7.89725116679238e-8 0.0001356112411836592 7.60772459566996e-5 -0.00011795728593041889 -0.00016935913978289433 1.2229621041678305e-5 -0.00027071145401479736; 3.265580471503264e-5 1.8565523531432164e-5 -3.492104497258094e-5 -0.00013984134329934108 0.0001802410852165461 -4.3526592891252786e-5 -7.391332383252317e-5 0.00010660553470966871 -7.730317065708998e-5 -9.840718758042834e-5 -5.477158539576887e-6 -1.6765447982318484e-5 2.9300676941213724e-5 3.828570424826241e-5 0.0001166503434420318 0.00016962694735281277 2.4885574929241845e-5 -0.00019714313293792601 -6.296991978291701e-5 -3.2301804835168947e-6 8.160353263388749e-5 -2.2458082641073175e-5 -6.809430397101333e-5 -0.0001760640288934536 -0.000139935363223618 -1.0406807633825965e-5 -9.993562616859024e-5 -0.00011036011619711848 -3.4597774175761686e-5 -5.180301286731424e-5 -0.00012289511414693172 6.965702868780463e-5; -5.882298328993929e-5 -0.00010929211394592271 3.9004878901540374e-5 2.2486281425006774e-5 0.00013199163540069867 0.0001948171428850747 -1.66872779189013e-5 -8.381692724266172e-5 -4.5690418217201814e-5 6.892304544838966e-5 7.857352004283157e-6 6.679787647248154e-5 -5.929274821734154e-5 
-0.00021055518128174436 0.00012817830601510492 -9.687197426297793e-5 9.160219041228534e-8 -3.3036032716846966e-5 -1.6303670079301702e-6 -0.00011677846315047786 7.62012149536881e-5 0.00011252429421095998 0.00011171904942988337 -4.205144260391678e-5 -2.6488274768563985e-5 1.109406447743658e-5 -0.00010528417452202262 5.8964354950365697e-5 6.65932474405401e-5 2.043821213395717e-5 -4.7330782861319596e-5 -5.6017308550291703e-5], bias = [1.4872269623449693e-9, 1.422779104351726e-9, 1.7088210674891806e-9, 1.0354371634227448e-9, -2.424148469742783e-9, -5.202127650405498e-10, 3.803969894681222e-10, -1.1397078915637993e-9, 5.960635886442348e-10, 1.247764004602381e-9, -2.3813481361895415e-9, 4.536682773229745e-10, 2.7598103619180754e-10, -2.2064708553446104e-9, -2.814482288801616e-9, -3.3417440396572353e-10, -2.162014928446731e-9, 1.2166811685810749e-9, 3.9497641795865116e-10, 2.060861755799659e-9, -2.6402601417690258e-9, 1.5277086672013824e-9, -9.39130212762763e-10, 4.524061564917468e-10, -3.022803997806722e-9, -1.3830384307125835e-9, 1.6954023313513813e-9, 1.8168562386327903e-9, -2.7018306512385716e-9, -4.590639432384901e-9, -2.7348518803520576e-9, 3.082481130184344e-11]), layer_4 = (weight = [-0.0005569026526867425 -0.0005834173164898069 -0.0006862906502027301 -0.0005449053813403325 -0.0006304611757286252 -0.0008365600639539143 -0.0007193783468289277 -0.0007509587188820877 -0.0007958570088295384 -0.000798536213288154 -0.0006939192720298581 -0.000673664695190933 -0.0007990821348696907 -0.0004935703464203383 -0.0006133259266928992 -0.0007461595614096153 -0.0008010292890788358 -0.0007004997658944188 -0.0006883194331210015 -0.0007636555564584072 -0.0009091077996881857 -0.0005313218680426609 -0.0005727081290532456 -0.0007684045262978893 -0.0007629859015675527 -0.000684091433411392 -0.0007931203208858267 -0.0006406241160967621 -0.0006223091255315281 -0.0006624585548411669 -0.0006785681187618932 -0.0007052456665850079; 0.0002924991047717859 0.000138890308062888 
0.00024948023559386403 0.00018720122653124733 0.00019536072877762166 0.0001099842886124294 0.00023709958301711047 8.102802038222588e-5 0.0003267666937288479 0.00014288755462456154 9.743960461701162e-5 0.0002678530153479608 0.00022331673071788514 0.0002979118977942022 0.00033093634309316086 0.0002206251681626384 0.00012145446242611257 0.000387392953196379 0.0002728119919720267 0.0002809026495858466 0.00015375185509655613 0.0002459437823906608 0.00019347975732193328 0.00017161304720433922 0.00013591693968584268 0.00015732972409714175 0.00033629688764053457 0.0003011487460875888 0.0004768074664065827 0.0002616609203763479 0.0004056385142093312 0.00014257637561767756], bias = [-0.000683795039433923, 0.000251342468660997]))Visualizing the Results
Let us now plot the loss over time
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    # One entry per callback invocation, i.e. per recorded optimizer iteration.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; marker=:circle, strokewidth=2, markersize=12)
    fig
end
# Finally let us visualize the results
# Re-simulate with the trained parameters `res.u` and keep the first element returned by
# `compute_waveform` (the waveform).
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn =
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false) |> Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # Each series is drawn as a line with semi-transparent circular markers on top.
    line_style = (; linewidth=2, alpha=0.75)
    point_style = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l1 = lines!(ax, tsteps, waveform; line_style...)
    s1 = scatter!(ax, tsteps, waveform; point_style...)
    l2 = lines!(ax, tsteps, waveform_nn; line_style...)
    s2 = scatter!(ax, tsteps, waveform_nn; point_style...)
    l3 = lines!(ax, tsteps, waveform_nn_trained; line_style...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; point_style...)

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
# Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices, the backend package and a
# functional device are all available. The conditions short-circuit left to right, so
# the device types are only touched once both packages are known to be loaded.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.4
Commit 01a2eadb047 (2026-01-06 16:56 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0This page was generated using Literate.jl.