Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie

Define some Utility Functions
Tip
This section can be skipped. It defines the functions used to simulate the model; from a scientific machine learning perspective, however, it isn't especially relevant.
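We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$ and use the Newtonian formulas $r_1 = \frac{m_2}{M} r$, $r_2 = -\frac{m_1}{M} r$ with $M = m_1 + m_2$ to recover the individual paths.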
"""
    one2two(path, m₁, m₂)

Split a one-body (relative) `path` into the paths of the two bodies.

With `M = m₁ + m₂`, the first body is placed at `m₂ / M .* path` and the second at
`-m₁ / M .* path`, so that their difference reproduces `path`.

# Returns
The tuple `(r₁, r₂)`, each with the same shape as `path`.
"""
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix into Cartesian orbit coordinates.

Each column of `soln` is one time sample. Row 1 is `χ` and row 2 is `ϕ`. When `soln`
has 4 rows, rows 3 and 4 supply `p` and `e` per sample; when it has 2 rows, `p` and `e`
are taken from `model_params = (p, M, e)` (the middle entry is not used here).

The radius is `r = p / (1 + e * cos(χ))` and the orbit is `(r * cos(ϕ), r * sin(ϕ))`.

# Returns
A `2 × size(soln, 2)` matrix whose rows are `x` and `y`.

# Throws
`AssertionError` if `soln` does not have 2 or 4 rows, or if a 2-row `soln` is given
without a length-3 `model_params`.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Checking for `nothing` first gives the documented assertion instead of a
        # `MethodError` from `length(nothing)` when the default argument is used.
        @assert model_params !== nothing && length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    orbit = vcat(x', y')
    return orbit
end

# This function uses second-order one-sided difference stencils at the endpoints;
# see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled values `v` with spacing `dt`.

Interior points use the centered difference `(v[i + 1] - v[i - 1]) / (2dt)`; the two
endpoints use three-point one-sided stencils. `v` must have at least 3 entries.

# Returns
A vector with the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end

# This function uses second-order one-sided difference stencils at the endpoints;
# see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled values `v` with spacing `dt`.

Interior points use the centered stencil `v[i - 1] - 2v[i] + v[i + 1]`; the two
endpoints use four-point one-sided stencils. `v` must have at least 4 entries.

# Returns
A vector with the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the mass-weighted moment tensor along an orbit.

Rows 1 and 2 of `orbit` are read as `x` and `y`. For `component == (1, 1)` the result
is `x² - (x² + y²) / 3`, for `(2, 2)` it is `y² - (x² + y²) / 3`, and for any other
`component` it is `x * y`. The result is scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    row, col = component[1], component[2]
    xx = xs .^ 2
    yy = ys .^ 2
    third_of_trace = (xx .+ yy) ./ 3
    if row == 1 && col == 1
        return mass .* (xx .- third_of_trace)
    elseif row == 2 && col == 2
        return mass .* (yy .- third_of_trace)
    end
    # Every other index pair falls through to the off-diagonal term.
    return mass .* (xs .* ys)
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the selected moment-tensor `component`,
computed from `orbit2tensor` and differentiated with `d2_dt2` using spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate `h_22_quadrupole_components` for the `(1, 1)`, `(1, 2)` and `(2, 2)`
components and return them as the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the two strain polarisations for a single orbit with unit mass.

The quadrupole components from `h_22_quadrupole` are combined as `h11 - h22` and
`2 * h12`, then scaled by `sqrt(π / 5)` evaluated in type `T`; the second output
carries a minus sign.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    hxx, hxy, hyy = h_22_quadrupole(dt, orbit)
    amplitude = sqrt(T(π) / 5)
    plus = amplitude * (hxx - hyy)
    cross = -amplitude * (T(2) * hxy)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components of two bodies.

`h_22_quadrupole` is evaluated for each `(orbit, mass)` pair and the results are added
component by component, giving `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    return (
        first_body[1] + second_body[1],
        first_body[2] + second_body[2],
        first_body[3] + second_body[3],
    )
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2) mode strain from the orbits of BH1 (mass `mass1`) and BH2 (mass
`mass2`).

The masses must sum to one (within `1.0e-12`). The summed quadrupole components are
combined as `h11 - h22` and `2 * h12`, then scaled by `sqrt(π / 5)` evaluated in type
`T`; the second output carries a minus sign.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    hxx, hxy, hyy = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    amplitude = sqrt(T(π) / 5)
    plus = amplitude * (hxx - hyy)
    cross = -amplitude * (T(2) * hxy)
    return plus, cross
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into the pair of strain polarisations.

# Arguments
- `dt`: spacing between consecutive columns of `soln`, used for the finite differences.
- `soln`: state matrix accepted by `soln2orbit` (2 or 4 rows).
- `mass_ratio`: value in `[0, 1]`. `0` selects the one-body strain; any positive value
  splits the orbit with `one2two` using `m₂ = 1 / (1 + mass_ratio)` and
  `m₁ = mass_ratio * m₂`, then uses the two-body strain.
- `model_params`: forwarded unchanged to `soln2orbit`.

# Returns
The tuple returned by `h_22_strain_one_body` or `h_22_strain_two_body`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

# Simulating the True Model
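`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}}$$

where $u[1] = \chi$, $u[2] = \phi$, and $p$, $M$, $e$ are constant parameters.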
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the reference orbit ODE.

`u` holds `(χ, ϕ)`; only `χ` enters the rates. The second argument is destructured into
the constants `p`, `M`, `e`. The time `t` is not used.

# Returns
The two-element vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)
    # Factor shared by both rates.
    numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ_rate = numer * sqrt(p - 6 - 2 * e * cosχ) / (M * p^2 * denom)
    ϕ_rate = numer / (M * p^(3 / 2) * denom)
    return [χ_rate, ϕ_rate]
end
# Problem setup shared by the true model and the neural-network model below.
# NOTE(review): apart from `ode_model_params` these are non-const globals, and several
# of them are read inside `loss`; consider marking them `const`.

# Passed to `compute_waveform`; a value of 0 selects its one-body branch.
mass_ratio = 0.0 # test particle
# Initial state [χ, ϕ] for both ODE problems.
u0 = Float64[π, 0.0] # initial conditions
# Number of saved time points.
datasize = 250
# NOTE(review): Float32 literals here, while `u0` and the parameters are Float64 —
# confirm this mix is intended.
tspan = (0.0f0, 6.0f4) # timespan for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
# Spacing between saved samples; passed to `compute_waveform` as its `dt`.
dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step handed to `solve` together with `adaptive=false`.
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the reference model with fixed-step RK4, saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first of the two outputs returned by `compute_waveform`.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end

# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (the orbital phase χ, i.e. the first component of the state) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.
It is typically never recommended to use globals but in case you do use them, make sure to mark them as const.
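We will deviate from the standard Neural Network initialization and use WeightInitializers.jl to draw the weights from a truncated normal distribution with a small standard deviation (1.0e-4) and to set the biases to zero: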
# Network evaluated inside `ODE_model`: one scalar input, two outputs.
# Every Dense layer draws its weights from `truncated_normal(; std=1.0e-4)` and starts
# with zero biases (`zeros32`).
const nn = Chain(
# Parameter-free first layer: calls `fast_activation(cos, input)`.
Base.Fix1(fast_activation, cos),
# Two hidden layers of width 32 with `cos` activation.
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
# Output layer: 2 values, no activation given.
Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# Create the parameters `ps` and states `st` of `nn` with the default RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
# Printed result (numeric values elided). `ps` holds Float32 parameters:
#   layer_1 = NamedTuple()                         (no parameters)
#   layer_2 = (weight = 32×1,  bias = 32 zeros)
#   layer_3 = (weight = 32×32, bias = 32 zeros)
#   layer_4 = (weight = 2×32,  bias = 2 zeros)
# with all weights of magnitude ~1e-4. `st` is
#   (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(),
#    layer_4 = NamedTuple())

# Similar to most deep learning frameworks, Lux defaults to using Float32. However, in
# this case we need Float64
# Float64 copy of the parameters, flattened into a ComponentArray.
const params = ComponentArray(f64(ps))
# Wrap `nn` with its state so it can be called as `nn_model(x, parameters)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Printed result:
# StatefulLuxLayer{Val{true}()}(
#     Chain(
#         layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#         layer_2 = Dense(1 => 32, cos),                # 64 parameters
#         layer_3 = Dense(32 => 32, cos),               # 1_056 parameters
#         layer_4 = Dense(32 => 2),                     # 66 parameters
#     ),
# )         # Total: 1_186 parameters,
#           #        plus 0 states.

# Now we define a system of odes which describes motion of point like particle with
# Newtonian physics, uses
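$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathrm{NN}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathrm{NN}_2(\chi)\bigr)$$

where $u[1] = \chi$, $u[2] = \phi$, the constants $p$, $M$, $e$ are taken from `ode_model_params`, and $\mathrm{NN}_1$, $\mathrm{NN}_2$ are the two outputs of the neural network.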
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE.

`u` holds `(χ, ϕ)`. The constants `p`, `M`, `e` come from the global
`ode_model_params`. Both rates are the common factor `(1 + e * cos(χ))^2 / (M * p^(3/2))`
multiplied by `1 + nn_model([χ], nn_params)`, one network output per rate. The time `t`
is not used.

# Returns
The two-element vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple hence we can safely ignore
    # it, however, in general, we should use `st` to store the state of the neural network.
    y = 1 .+ nn_model([first(u)], nn_params)
    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))
    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]
    return [χ̇, ϕ̇]
end

# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
# Integrate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings as the reference model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end

# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the waveform predicted with network parameters `θ` and the
reference `waveform`.

The neural ODE `prob_nn` is re-solved with `p = θ` using fixed-step RK4 and saved at
`tsteps`; the first output of `compute_waveform` is compared against `waveform`.
"""
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end

# Warmup the loss function
loss(params)  # printed result: 0.000717775506285655

# Now let us define a callback function to store the loss over time
# Loss value recorded at every callback invocation.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses` and print a progress line.

`θ` must expose an `iter` field, which is printed as the iteration number. Always
returns `false`.
"""
function callback(θ, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
    return false
end

# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.
# Gradients of the loss are obtained through Zygote.
adtype = Optimization.AutoZygote()

# The optimisation problem has no hyper-parameters, so the second argument is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search and a small initial step.
optimizer = BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking())
res = Optimization.solve(optprob, optimizer; callback=callback, maxiters=1000)
# retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-4.546650961851169e-5; 0.00015079529839558448; 0.0003032951790368662; -0.00012051864177902437; 4.106310825592823e-6; 0.000236311883781955; -2.1043735614495612e-5; -0.000204126612516098; -5.284389408186739e-5; -8.110501948975564e-5; -7.617983646928387e-5; 0.00021071849914727737; -0.00011335984163446386; 1.62920478032941e-5; 1.895137211248133e-5; -3.9944876334600185e-5; 0.00012434835662123513; -7.201537664514633e-5; -0.0003451296652199775; -0.0001416170271113212; -1.2197997421007822e-5; -0.0001141863394876262; -6.484027107946154e-5; -7.098756032057957e-5; -1.2202059224353523e-5; -0.00013759380090036204; -0.00012981322652185482; 0.00010036958701671689; -2.9328131859079544e-5; 0.00016927460092100598; -3.5277149436239195e-5; -2.3791993953608492e-5;;], bias = [-5.68203821701402e-17, 1.055962810622561e-16, 5.752507013088124e-16, -1.9374869307862354e-16, -1.7776470538091363e-18, 8.76571456571026e-17, -1.8386375209895245e-17, -2.4254601630350623e-17, -2.3435222675097744e-17, -1.5242244283921668e-17, 5.282133755971996e-17, 5.938397727759837e-17, -6.320963695291119e-17, -4.433186042470579e-18, 1.7038463655761762e-17, -3.313172807434013e-17, 1.1819455553963292e-16, 6.392651209415518e-18, -1.9643464197616142e-16, -3.289404748646528e-17, -2.046212406139071e-17, -1.7896422243439814e-16, 1.6931627174809285e-17, -7.331458519212095e-17, -1.2245549426079331e-17, -1.5590604055907467e-16, -1.0218053281405347e-16, 4.473184270882644e-17, 5.422149748080283e-18, 2.1121530035282708e-16, -6.383762747759587e-17, -3.1884246086087485e-17]), layer_3 = (weight = [0.00012512232681674709 -1.1382657040114077e-5 -0.00013871825901385065 8.211558199891974e-5 7.148034385809444e-5 2.2567636942237942e-5 -0.00016186580139273998 1.6952252944441823e-5 -0.00012245489303549907 -8.187610616969e-5 -3.5443360293899505e-5 9.117500051562683e-5 0.00015946831011152003 6.968275304589198e-5 -9.545267200391961e-5 8.850971634639874e-6 
-4.518014499862501e-5 -7.00817570492262e-5 -3.885051662160895e-5 -3.3749254458530727e-6 -2.40211745864129e-5 0.00011511062369154879 -1.1707731078760424e-6 4.134191952234968e-5 -6.646279207960407e-6 -6.81478438951607e-5 6.624415372613399e-6 2.2275754626575277e-5 3.4290050488745575e-5 0.0001074692602100276 -5.9162829321028775e-5 -0.00022651435826501093; -0.00014471270886086057 -1.7328892040354282e-6 9.44255206563848e-5 -7.355522582609624e-5 -0.00010487427978612566 4.419802819075552e-5 9.411213061004784e-5 -2.0136098173215273e-5 -1.6749147179761517e-5 0.00012554078959557926 -3.005326114314825e-5 -0.00013721591138118124 5.677726032492325e-5 -2.8048016319625044e-5 -6.551433223601523e-5 0.0001347880078542593 -1.2668610069779425e-5 7.157217434771157e-5 -4.5696799206333645e-5 0.0001985924663317894 -1.606376834340946e-5 -2.5531338978436897e-5 0.00013107217630069524 -2.5899604297119323e-5 0.00013119725001208398 1.3236788545797493e-5 -5.12329261978111e-5 -0.00013723590571269707 8.598954984383555e-5 -9.69301346044503e-5 6.957262478001305e-5 -3.395432766363121e-5; 4.9965663613569614e-5 -8.259076263796918e-5 -3.251390432281409e-5 -0.0001427979294633908 6.853877869858829e-5 8.086206810973393e-5 0.00021382787975441156 1.523847579543545e-6 -7.413706759265185e-5 9.057264117715407e-5 -0.00018852580966935038 -0.00025935861037154005 3.237047361168575e-5 -5.4616401642634916e-6 4.648704283014365e-5 1.7107798909491526e-6 0.00015749874012969012 -1.4035062911112457e-5 -0.000187525059907174 -6.664662929991655e-5 2.487437826190568e-5 -2.8907876064600115e-5 -2.749107795110457e-5 0.0001295227413108152 0.0001258340345272371 -6.595420551760524e-5 1.9168537740065316e-5 2.8936971041425273e-5 0.00020043177949688823 -0.00020772656704806135 -9.622098865960295e-5 0.00016212344065267087; 9.946796685792153e-5 1.4895190498933213e-5 -8.615889041585187e-5 -4.989961188113383e-6 -3.3496452019579e-6 -1.4343893643395233e-5 -0.00012680896711433608 -0.00011084656384672654 -1.8681403024381977e-5 
0.00017658106130649773 -2.417149321306962e-5 -6.29168333475641e-5 3.365667557911269e-5 8.858862033914476e-5 -4.877770502250943e-5 0.00022928355806609276 -5.0824821163116136e-5 -8.181494057206249e-6 -4.162355057680205e-5 4.6408919071397174e-5 -6.861867391169366e-5 0.00014110093084265725 4.713048306395813e-5 1.7433734821244374e-5 -6.456111973282245e-5 -8.229179170346749e-5 -3.993529706788085e-6 -7.18356658112317e-5 0.00011654110326300991 4.552835721507445e-5 3.3263420984006e-5 4.52941659673552e-5; 0.00027452315480106047 -0.0001905482706228825 -6.0166617188237034e-5 0.00010385066779920308 7.430532875801181e-6 -0.00012029493446101857 -0.00025650181414901856 1.9557352786748476e-5 1.3453605604370183e-5 0.00015023488303770162 0.00021966639251106218 -0.00012647113835778377 -2.71108520107635e-6 -8.092629214159804e-5 -8.181850644403721e-5 -3.929412471264683e-6 3.4726418377910827e-5 2.509090590821799e-5 -6.787003003627715e-5 7.125640040388195e-5 -0.00011249537710908468 -5.1126967448430826e-5 -1.909844213740283e-5 -0.0002475448774953118 -6.046638664200569e-5 -1.4308617394335852e-5 5.0579478554507834e-5 -0.00012648458432747888 0.00024816818113096445 -9.76299335911451e-5 -6.669479553464389e-5 5.130111594695901e-7; 2.5339280484889644e-5 -5.889708196936007e-5 6.965622881817422e-5 0.0001701860542494732 6.35224874100584e-5 -7.916290493757477e-5 -0.00010003827489289163 -0.0002976921736875636 1.1911094163782405e-5 8.181922141933119e-6 0.00013033118151053118 -0.00035902730311808124 0.00011522926937015019 2.352621734882064e-5 -0.00011825012776842982 0.00014632367813884566 -8.124831281190519e-6 4.147492066532638e-5 0.0001411895005117973 -7.794678683002297e-5 4.2869518013152425e-5 -0.00013944678946007754 4.986689740018258e-5 -1.781531906761896e-5 -0.00017421224626321986 4.111949013587277e-5 -2.397059912926782e-5 0.00014466807763948342 2.3784666639234017e-5 0.00011407080595073425 2.5822746040479337e-5 0.00010912125574815378; 0.00014331281162397808 -3.004938954307043e-5 
0.00011387064913813796 -6.890556797786466e-5 -3.688531179265143e-5 5.63783094456126e-5 -3.2362838474149344e-5 9.54195773281869e-5 1.4182626536999075e-5 6.936986645981302e-5 -5.707452792345775e-6 -3.8335242798156205e-5 2.037249068807053e-6 -0.00012618788715386796 0.00013921151438793126 5.043731179269855e-5 5.308914368454469e-5 -0.0001331353934894345 -0.0001858322150021483 0.00014467829058942245 9.004720404697272e-6 0.00014526991325495062 -3.533912113349163e-6 3.3581020752384256e-5 -7.54412332845855e-5 -5.0150663091857775e-5 -1.7289821336572545e-6 -5.798642560891424e-5 0.00022301439347742057 -2.699750913068772e-5 3.718919723709026e-5 -3.222776759699972e-5; -1.9220492330597923e-5 2.0948692682327377e-5 0.00012828816869924978 -7.457553402589272e-5 -3.889809050549186e-5 0.00016321938636876422 7.398663473480013e-5 -0.00013319417740520107 -5.0857316463472626e-5 1.9201695241397342e-5 -6.534800450851656e-5 5.1123123913313616e-5 4.4406214502420405e-5 -0.00028162094787064385 0.00015522644133937174 6.574082107447634e-5 3.625607968472294e-5 -0.00012003038392237826 -7.394092500279837e-5 -4.028405154781128e-6 0.0001334328382571843 0.00012164402614046272 -7.6255028770058e-5 2.0454908198441424e-6 1.1364121920385927e-5 0.00010761841003015937 -1.5102865117181982e-5 -4.053769845001905e-5 -5.553290867198635e-5 5.243387676357455e-5 2.9929057919551524e-5 -5.287388813401014e-5; -6.382559910800442e-5 -4.940602627267462e-5 0.00015072086865167386 4.896821911629107e-6 -2.57309860617362e-5 6.305326281199137e-5 -7.124072396347489e-5 -8.409123725004469e-7 -3.5243480278059044e-5 -3.283485454541604e-5 -0.00012174346590565518 6.49958416315677e-5 -0.00015505986481822778 -0.0002140026302813758 5.919247206274118e-6 -7.287691948387923e-5 0.00020910743081093178 -0.00013710357267793287 -3.0360834324687707e-5 0.00012413926345378664 8.871209468541466e-5 -0.0001564964735453178 1.7570380469788226e-5 8.615866709970592e-6 7.873049219686337e-5 0.0001374827294339286 -0.00021233278345699963 0.00014366311700631087 
-2.031349930822042e-6 8.147077790544242e-5 1.7790767406931395e-5 -9.496933886932396e-5; -0.0002370553151701567 -1.0671288196758343e-5 -5.966056729053975e-5 0.00016740767436596124 -6.57530195502872e-6 7.04112504029824e-5 -6.746128515778055e-5 8.672318057799093e-6 -0.00011305667017558463 6.344829720938825e-5 -7.152689382172897e-5 6.878417165746472e-5 6.462436117027394e-5 -5.6793938216007756e-5 -0.00021010870428262227 0.00013023160549374384 -2.984661610361133e-5 2.6183476206196627e-5 2.913445545741373e-5 6.0175647872087676e-5 -7.309979578684969e-5 -8.344717985608345e-6 -3.7233319249654335e-5 0.00010001905006384676 9.429778269287245e-5 4.304214648465324e-5 4.904144638595735e-5 -5.886801541754622e-5 -0.00026699022005749793 2.198843371326569e-5 0.00011046705598738125 -4.913019025088229e-5; 1.4335181521483282e-5 -2.325439548333198e-5 -0.00018289304373933545 -0.00016353378690003853 0.00013809880671865183 2.973556344769682e-5 5.225288238684572e-5 -3.546083483415986e-5 -8.557411295519033e-5 -1.6617732608978426e-5 3.590118992101898e-5 -0.00014862095507014582 -6.188096204675113e-6 3.072563028504046e-5 -0.00019493906095781446 6.220602860516917e-5 -9.583823360192924e-5 9.722697338398047e-5 -2.8700125977331532e-5 4.440234239989116e-5 -2.5395680705379833e-5 4.6159573475200334e-5 -9.418335306208721e-7 8.412922376842652e-5 -3.7983554858193874e-5 -0.00024289789528741465 2.1779458410644916e-5 5.344937723672039e-5 7.627443462219992e-5 -0.00017690092884624234 -8.014579746563907e-5 -9.214435361250365e-5; 6.742013353681091e-5 -3.369228873458486e-5 4.333404167436068e-6 1.2671562008050647e-5 -4.488148918432314e-5 6.283649851861176e-5 7.535043984711045e-5 -0.00014667299840772033 -0.00020030828748537316 9.738349469438161e-5 0.00012144317294763356 -1.1796027347426241e-5 -0.00016019484176473634 -8.447386089816106e-6 -9.548733508387708e-5 -1.394769678941728e-6 9.408362963762233e-5 0.0001121324210270675 1.9941250475270287e-5 -2.988145136447457e-5 -3.5429052541136165e-5 -8.153649398997363e-5 
-0.00015637687031818623 0.00017469293813513536 0.00011617532142730298 -8.775927129141835e-5 0.00024492043737148975 -1.1507258232144923e-5 -0.00014346219107216345 5.1334762947534346e-5 -3.258657964585462e-5 -9.17944009689254e-5; 8.079389664146813e-5 6.308587044835754e-6 -4.637824094001288e-5 8.038009838003835e-5 -6.67163359952376e-5 7.260060084322129e-5 -4.747550991794386e-5 9.435953398737254e-5 -0.00011167252813663547 0.00019928841794410632 0.00010418933089624975 3.225690016792373e-5 1.4579862278245696e-6 -0.00015251795702113133 -4.621656916185577e-5 -5.3172327073968305e-6 -0.00012169264954479211 -3.1145703532613325e-5 0.0001423414913103158 3.1869484528790946e-5 1.65863191892875e-5 2.6898392034700414e-5 5.44053735308146e-5 -6.330321338225296e-5 -2.211371355713181e-5 -9.129926474030746e-5 -4.343186009406797e-5 -0.00014798774208372434 -0.00011581198777037021 -0.00019512100260668736 -1.4217613342565345e-5 0.00012059539295535245; 8.99200537372506e-5 -0.00010653649164458409 7.772359881911715e-5 0.00011968125968845136 -7.377469532516184e-5 0.00027246195591421377 -1.2539649975658656e-6 7.077798867877404e-6 -1.5301065419090502e-5 5.9113267211017156e-5 -0.00012512800398840752 -9.845455437763397e-5 -9.2992151026616e-5 0.0001470933354124577 -3.011388921154107e-5 -0.00013116188287826794 -5.568097125948959e-5 7.427069854597464e-5 0.0001365202630533313 5.0744300692097666e-5 0.00019109345217143772 3.1208787417502685e-5 0.00010651932157483921 0.0002447207959033634 -0.00013539458390882583 0.00010075498053477056 -2.996698217036477e-5 9.245547758322383e-5 -5.539736534560544e-5 6.506106060834131e-5 7.586717461376885e-5 9.927560372884913e-5; -5.574627703335541e-5 -3.4567983496701433e-5 1.6409971246159756e-5 0.00012601586440414954 -4.6961198498499355e-5 -2.9595181152567543e-5 -5.1275613290243095e-6 -5.301420588092524e-5 6.546939057134633e-5 -0.00014368208558461948 6.13492558810361e-5 -8.52747505663388e-5 -7.480456024804351e-5 -0.00015222656275309132 -9.096707417793078e-5 
-7.786061524111259e-5 -6.427962562465257e-5 5.7677304508653915e-6 -0.00026513864195601113 0.00010696050976240016 -0.00013475371434051056 0.00022253480364871017 9.319113105139334e-5 -6.896307950537067e-5 2.0422179838172408e-5 0.00014477542226505385 1.5260177130227626e-5 3.7561094538073335e-5 3.964364823023361e-5 0.00015484783598058963 5.737303963233296e-5 6.171531658456228e-5; -0.00012714579829826734 -6.478289619900505e-5 -8.928744832867582e-5 0.0001496779335395242 -4.1756338443343715e-5 3.4197292141635066e-5 0.00018098041286811482 8.074686443086092e-5 9.66613846214322e-5 -3.989532310861973e-5 3.340746511475322e-5 2.9829623916323955e-5 -8.126143407190229e-5 0.00010046011116009461 0.0001429027218802671 -9.757878765577225e-5 -0.00016410922003325142 2.2544864212425072e-5 -4.739285909803051e-5 7.077390577995986e-5 9.559835993804296e-5 -5.975013450716637e-5 -9.537249175246635e-5 6.433612393101831e-5 -4.251874602396715e-5 -7.399688594715209e-6 -0.0001280609973508129 2.2698501514415732e-5 0.00017537846392605695 -9.846478101445708e-5 -8.906855011270366e-6 6.003696430361147e-5; 0.00012519147367999185 -0.00010883601406967843 -0.00016920742108622667 -0.0001032289501294295 0.00018974113573636797 9.631287049128117e-5 6.586571635995057e-5 8.332696788443753e-6 2.9763447864356945e-5 5.157125551314583e-6 0.00012634228279200202 0.00015095197250923764 -0.00013682326879499592 -3.645201558396249e-5 -0.00014826666209995763 -0.0002442944014459833 0.00011925780655436475 5.35426270376492e-5 -0.0001858702371375821 -1.8002850074852132e-5 -0.00027581905941627963 -9.23447869267137e-5 9.72110656308063e-5 -3.392123743690833e-5 -0.00014997539164181842 -0.00012855651492865011 -0.0001437713863109938 -7.190798433315163e-6 -0.00013057166414949986 0.0001479148422819029 -0.0001705222448348988 -6.912282605338828e-5; -4.745572396716783e-6 -5.000394506228343e-5 7.449879646566254e-5 3.421100424871024e-5 -3.594399007001078e-5 -2.1498247434164076e-5 0.00014617679423223112 0.0001836049168280969 
6.539651528269745e-5 -5.346212040480633e-5 -9.756839702335558e-5 2.074796700011264e-5 -4.52482336464965e-5 -0.00016944131041840782 -1.102022636645606e-5 5.94379498300962e-5 0.0002414813179433173 -0.00015704189993870788 6.203257814313512e-5 4.84597157713764e-5 -0.00020056795895529104 -0.00010008369557058199 0.00010899385685602831 4.052702294698578e-5 8.279012124090333e-5 -5.987542957032118e-5 2.5121368680038283e-5 -8.183266893924776e-5 -2.3743724369198704e-5 0.00010940388617140726 0.00018269115841517986 -1.0723790210833795e-5; 6.392915094596476e-5 6.224576356448429e-5 4.081866865725016e-5 -8.836912018619674e-5 -2.7624543855441274e-5 1.912975023420898e-5 1.8441679293513958e-5 2.0950564084119856e-5 -2.1372338943423785e-5 -5.875253174110466e-6 -0.00026538638774678726 -1.79796773503999e-6 7.694645757414125e-5 0.00015270384466833494 4.566678234309112e-6 -0.00012854848608282273 -6.447793128341586e-5 -3.8657212991040996e-5 7.724098106245629e-5 -1.5028420585672639e-5 -5.193649047311342e-5 -5.8040411368946414e-5 0.0001940481641385636 -3.7059096194810186e-5 -0.00010104384562606872 -7.252705398155363e-5 -3.0205846252153554e-5 -5.852368468422796e-6 -4.9743625987562426e-5 -3.7352491909633284e-5 -2.0870086865275056e-5 1.5186315812511159e-5; -0.000330992877567257 2.2466371232242966e-5 -0.00010666347006567649 3.1518119070620096e-5 -2.2811632876472653e-5 -3.8779218444855455e-5 -4.7098191859039203e-5 1.0163659260815318e-6 0.00012283950709730904 -0.0001404094997235302 -9.313996051444817e-5 -0.00017813935628164046 8.383981403655857e-5 0.0001118066888045735 3.8833461443356445e-5 -8.936864978408032e-5 -0.00010795807392807894 8.75691424914895e-5 8.549763370308762e-5 -0.00015019337806390398 6.583024005864096e-5 -0.00011187526033265946 2.0655589108869344e-5 -0.000127757846345261 -0.0001407425057516219 9.82756537612624e-6 4.609498719562687e-5 -2.300317064166764e-5 0.00013201412585091193 0.00025865177522408804 -8.293561185950035e-5 -3.682953101473543e-5; -8.512531833369655e-6 
4.098380601808579e-5 2.0587881651787755e-5 -0.00012842597491962852 -4.977929778088196e-5 1.2008234323192234e-5 0.00011038967501680995 4.120835298399976e-5 3.3693576613038096e-5 9.146014351008281e-5 -4.6458779918432315e-5 4.3915925987135936e-5 -3.207469192839278e-5 9.899188741419158e-5 4.286285117577457e-5 0.0001382440707560787 -2.3841252852980955e-5 -9.696818299981041e-5 5.962310685156159e-5 -7.972379238035435e-5 -0.00016980259964166516 -0.00013490262493414851 -7.902391801744927e-5 7.05102895766662e-5 2.848068403854658e-5 7.235221279861478e-5 -2.9744455544096113e-5 -1.00301992757536e-5 -9.156074950861107e-5 9.72006206884785e-6 -6.345442427667395e-6 -9.472323720965835e-5; -0.00017833977948816503 7.249622536416802e-5 -7.593335105566251e-6 -0.00012879539757908314 7.306321891315117e-5 -5.161225748652224e-5 -0.00024210989544576168 0.0001347219320066761 -0.00019445356632990715 -8.626538805605506e-5 -1.785593780572715e-5 2.1671102111022528e-5 1.2923709398604687e-5 2.7930804897326448e-5 -4.939054019109485e-5 -0.00024338793195260734 -0.00010821881482318894 -2.1522463832518585e-5 -5.65672671334179e-7 0.0001106708757631803 3.7360982567182126e-5 -0.00013094042264332167 -4.684403506166147e-5 5.919719729514079e-5 5.6795222533339035e-6 6.441209803394387e-5 -6.917930044640599e-5 0.00010709311278129598 0.00024265625211933526 0.00010455841653108755 -5.237028130272475e-5 -5.125682695711155e-5; -6.299918137954779e-5 0.00012334200030890722 3.1139770301448103e-5 0.00013089752931846657 0.0002194222138612716 -5.7301429903639515e-5 -4.120008131138112e-5 -2.8675751734433073e-5 5.18859010597381e-5 -1.4831069590570376e-5 -9.93446705862755e-5 -1.355767150674189e-5 9.793499660445897e-5 -3.771647649698934e-5 -0.00014307276111792823 -4.254109850733723e-5 0.000273473818681671 9.233472840861796e-5 -8.683149323343495e-5 -7.619065988110603e-5 9.409945372028244e-5 0.00010267808423378075 -0.0001564641611068368 -0.00011261205915360317 0.00010903732939625126 -6.557094135787447e-5 0.0001659609074820775 
9.176962660859216e-5 3.279537080090388e-5 0.00019088993897893288 6.343212183566617e-5 9.322812118804451e-7; -7.71371230770494e-6 -0.00012060361348234615 5.22105225394197e-5 -7.600801780030574e-5 0.00018407228765168816 -6.452393734038597e-5 0.00021418084054344917 0.00010753996741582945 3.829256115167956e-6 8.708447842251121e-5 -2.553958381726325e-5 -0.0001791338830804703 -4.117280275353662e-5 5.064920389882874e-5 8.15011268281157e-5 -5.077363619024289e-5 -6.817808332529368e-5 -6.424304899668699e-5 -0.00010448739647076684 -0.00011433758061288404 -5.269879636729482e-5 3.428361435395232e-5 -2.086777146872972e-5 -3.9405512666817825e-5 3.543187324281436e-5 0.00020605083102451165 4.755515749307822e-5 -2.8821046158228475e-5 5.832192311969235e-5 1.013578787420338e-5 4.531873968844262e-6 -6.356901883522396e-5; -0.00013542977914491626 -0.0001621466152909806 -3.519963221779171e-5 -3.772744905014525e-5 -2.8716124613286118e-5 1.9046418931858188e-5 8.027927750211135e-5 -8.49470970178482e-5 0.00023885726758001466 -1.4819556706213499e-5 -3.757256937839715e-5 3.790023682364524e-5 -0.0001304653350570903 6.445069364692926e-5 9.275308464841836e-5 4.816128337827518e-5 -1.0934800154209904e-5 2.3402423797080126e-5 -2.198744088389454e-5 2.505521031984518e-5 0.00015477285258660002 -0.00023585708633341938 -8.999388311801284e-5 -0.00015542558041212824 -7.082297922327208e-6 -6.325676160209501e-5 -3.845141774669128e-5 0.00011911590180082857 -7.306189301468475e-5 -0.00019804454579293657 0.00017255455084798843 -0.00010167730298869377; -6.859271082495182e-6 1.1445306142137092e-5 1.214345337462783e-5 0.00015958938746437804 1.2665522432806204e-5 -2.2219994946911448e-5 -0.00011503424031004367 -2.2628087038590688e-5 6.402883735675379e-5 0.00019426219222325964 -0.00017647797068390887 -0.00012732112082927138 -0.00015875740501839348 -6.5157274378427e-5 -1.3722495908288412e-5 0.00011410054429391377 0.00011270716930697141 4.177964291259733e-5 6.480397422525865e-5 2.7274772207178928e-5 -6.146530068988046e-5 
0.0001190204449692215 1.1783507203953917e-5 0.0001203527019122065 -0.00013885648223847216 -8.492212947500365e-5 -0.00010543660465670307 -3.0846259335947644e-5 -0.00011848213472446472 -0.00010899988120696637 5.989799609481324e-5 -1.3519563993460567e-5; -2.5319964528879488e-5 0.00013933305543617903 -0.00013836772033238087 8.122850211805655e-5 -9.276484835425473e-5 5.337541241320841e-5 -6.278923448406214e-5 5.912397760449756e-6 -0.0001175993578857545 -1.5728015017661492e-6 4.298363872805928e-5 4.281963500548352e-5 -1.9678952789373085e-5 -3.356855214804477e-5 -3.3582984009972434e-5 -2.7433719346280875e-5 -6.552767209938671e-5 7.396296367517016e-5 7.977803996774154e-5 -4.833055337654983e-5 -9.647661992346941e-5 6.311740670355855e-5 5.910998751062175e-5 -7.227361258117217e-5 -0.00020539577727902755 1.0879229726591776e-5 0.00014481444176056576 -8.471228412815687e-5 -8.215845210941534e-5 -0.00017909920044472405 -2.9384980202284623e-5 5.680371267373629e-5; 4.5184646909095445e-5 -0.000176029319233282 -6.922600307594165e-5 -6.60861581550224e-5 2.4095254660553477e-5 0.00010425220232130258 -7.292114908204656e-5 2.7773655050188344e-5 -1.5774306417821494e-5 4.186216090011425e-5 -3.918106200667278e-6 -0.00011063357244572854 -0.00010806859365992899 0.00015736208722364842 8.152118668795316e-5 0.0001548065380789041 -5.16828179811368e-5 -0.00018314149504157756 0.0002034294927157658 2.3017181988613517e-6 7.329749193877657e-5 -9.05469453996298e-5 0.0001446155408061753 -6.879374571007728e-5 -0.00010567157307968731 0.00025469257222858786 -2.5461841984442613e-5 -8.921775163479918e-5 -3.292337289758253e-5 -6.162874644950019e-5 -0.00015518233278862407 -0.00017117241379901072; -0.0001655664644889558 -2.4175783324021016e-5 0.00012328563309410635 -3.6208862070780134e-5 0.00012751465249005327 -5.057072928623749e-5 7.24773154839838e-5 3.540030877471453e-5 -7.482033549910608e-5 -3.414070022278415e-5 2.794203567157965e-5 -3.4153796946464946e-5 7.96511477571027e-5 4.451746201505522e-5 
-8.714969305890039e-5 -5.4312641509904494e-5 7.89318465873662e-5 5.178088315427262e-5 -3.5106883163142934e-6 1.58878075350402e-5 2.1371605839367046e-5 -0.00018277586502823506 -6.157904579173878e-5 -4.9800852935156366e-5 3.118368761045488e-5 -0.00013110346924382336 -3.08885836378599e-5 -1.0911708955394107e-5 0.0001470124290104923 -0.00023090296315363045 -0.00021540348541327118 0.0001009915168875794; -0.00021766011376305667 -3.078755512890666e-6 2.839021909911124e-5 -0.00019376345334013665 -0.00011610477848583309 3.335526163440348e-5 -1.4745729445648743e-5 -7.171905780112852e-5 5.170333135035716e-5 3.121378723742754e-5 -8.434236404857369e-5 -0.00014374483906565413 -1.0418579209285529e-5 6.432567717135502e-5 0.00010561051568808594 2.319811757340326e-5 -4.046737557065312e-5 -3.869227474344734e-5 -0.00015454302955312054 -0.00010619521525359778 -4.4331876766164085e-5 8.820210315615031e-5 -5.621008493610605e-5 0.0001120360302724983 3.324244427355542e-5 -4.3437246845776246e-5 -5.371995392625639e-5 -8.190996048644417e-5 7.11612865547711e-5 -5.416130260118333e-5 6.257270980720133e-5 -4.916838405327289e-5; -6.83909828309542e-5 0.00017064234192942985 3.6452273819423876e-6 -0.00011274366339208741 6.204828163603195e-5 1.0417428841114731e-5 1.2050439311076068e-5 -4.894959927391632e-6 0.00010423787093017113 -5.5720972211623234e-5 -2.6778280569181413e-5 -0.00014060637555091013 -2.8178860573155726e-5 -5.064190632552968e-5 -1.176204555819172e-5 -8.743591102544351e-5 3.863685846439374e-5 4.3467126617849264e-5 -5.702015992725871e-5 -4.0207109903043355e-5 -4.954569963744038e-5 -0.0002542624068679864 -0.00014923072818206766 0.00021753405521152396 2.7634579184757205e-5 9.234807579772346e-5 6.225997562281595e-5 5.403403260273692e-5 0.00018554194606834894 0.0001201879705534665 3.1415295909334515e-5 -6.597453544387723e-5; 8.919546196661678e-5 -0.0001282448287591681 -2.957389408295855e-5 2.3710179063664766e-5 -0.0001016116224699121 -4.4251203164857056e-5 -7.267321090953674e-5 
-2.3228578109794868e-6 -5.66697857928854e-5 0.0001462013820647811 -4.171517177642441e-5 -0.00013548948520369003 6.376856233282554e-6 -3.260511623271062e-5 -0.00016177386753302316 9.526390893110783e-5 9.202365233387109e-6 -0.0002605135018010079 -4.944129652614858e-5 0.00028802795675542423 -1.1539688260066551e-5 4.082063203636615e-5 9.092876241734095e-5 0.00010130352855153247 0.00016379653932485476 -4.3985383329389464e-5 3.132176386966522e-5 4.47217531265317e-5 7.440393490209934e-5 5.986352859649072e-6 -0.00010049316954141715 -0.00013710289970081543], bias = [-8.174270973302098e-10, 7.184285289664243e-10, -1.7450232567040924e-10, 9.332793735735282e-10, -2.1602340755393993e-9, 1.6874092923294248e-10, 1.6068859954850596e-9, 4.3400262118453674e-10, -1.1026432013236038e-9, -9.635693389537176e-10, -3.32168460946241e-9, -4.613230576875947e-10, -1.021534784968265e-9, 4.176994801023114e-9, -5.92811736060824e-10, 6.58790844305803e-10, -4.3864842510842965e-9, 1.350571763465696e-9, -1.4799125581209653e-9, -2.7845055583250854e-9, -6.227606752404733e-10, -2.3767299780481266e-9, 3.1973584739646025e-9, 2.079326438695634e-10, -2.1908972236711624e-9, -9.42976177444197e-10, -2.1440558888986303e-9, -1.3054218448877076e-9, -2.0198743828386726e-9, -3.421505907935848e-9, 1.2291123963866255e-10, -7.498082171953474e-10]), layer_4 = (weight = [-0.0008495808371169474 -0.0005954397257069386 -0.000696699824362647 -0.000664924875838767 -0.0007395343535249846 -0.0006790269017299368 -0.0007720618147121237 -0.0007792206358863066 -0.0005740164099624967 -0.0006508820328374556 -0.0006432574054288276 -0.0006352502327347465 -0.000534024556673348 -0.0007136855015684462 -0.0006808820492755729 -0.0005139472925882497 -0.0007086312108868869 -0.0008320121205739973 -0.0005697201827018768 -0.0006236100563682626 -0.000795314481566682 -0.0006277727551468119 -0.0007354369449959826 -0.0008082877553423258 -0.0008623395777290246 -0.0006684268374502794 -0.0007664802210008924 -0.0007981661750128714 
-0.0005478672579405271 -0.0005837243684817794 -0.0007778522870823012 -0.000776312528451077; 1.3003110213972118e-5 0.00020500484282893675 0.00032954191941390025 0.00029139215386867276 0.0002874273246808693 0.00017492275242789401 0.0002690819590288176 0.0002728363205257253 0.00014087160079308874 0.0002247578855069955 0.0001392147361864818 0.00020204356842004877 0.00022396674145364463 0.00024213630358148244 0.00022125028773852204 0.00029828454821209253 0.0001366731964899166 0.00036280151529130246 0.0001922316091607233 0.00011469089780267154 0.000437721960344828 0.00010283467437986819 0.00014145094692089884 0.0002680577055051008 0.00020872458760218525 0.0002901908349174092 0.00014557113032563664 8.479666423965536e-5 0.0005174843571809462 0.00025362461802528734 0.00037262349029085776 0.0003102562184078154], bias = [-0.0006868442094339936, 0.0002198772894972571]))Visualizing the Results
Let us now plot the loss over time
# Plot the loss values recorded in `losses` against the iteration index.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)

    iterations = eachindex(losses)
    scatter!(ax, iterations, losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end
# Finally let us visualize the results
# Rebuild the neural ODE problem with the trained parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)

# Same fixed-step RK4 integration as before, now with the trained parameters.
soln_nn =
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false) |> Array

# Keep only the first value returned by `compute_waveform`.
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
# Compare the reference waveform with the untrained and the trained predictions.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Attributes shared by every line / marker layer in this figure.
    line_kw = (; linewidth=2, alpha=0.75)
    marker_kw = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l1 = lines!(ax, tsteps, waveform; line_kw...)
    s1 = scatter!(ax, tsteps, waveform; marker_kw...)

    l2 = lines!(ax, tsteps, waveform_nn; line_kw...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_kw...)

    l3 = lines!(ax, tsteps, waveform_nn_trained; line_kw...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_kw...)

    # Each legend entry pairs a line with its markers.
    legend_plots = [[l1, s1], [l2, s2], [l3, s3]]
    legend_labels = [
        "Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"
    ]
    axislegend(ax, legend_plots, legend_labels; position=:lb)

    fig
end
# Appendix
using InteractiveUtils

# Print the Julia version and platform details.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are
# both defined and MLDataDevices reports the corresponding device as functional.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
# Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 9V74 80-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver4)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.