Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
# Define some Utility Functions
Tip
This section can be skipped. It defines the functions used to simulate the model; from a scientific machine learning perspective, however, it isn't particularly relevant.
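We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$, and use the Newtonian relations $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, with $M = m_1 + m_2$, to recover the two individual paths.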
# Split a relative (1-body) path r into the two bodies' paths using
# r₁ = (m₂ / M) r and r₂ = -(m₁ / M) r, where M = m₁ + m₂ (so r₁ - r₂ = r).
# Returns the tuple (r₁, r₂), each with the same shape as `path`.
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁, weight₂ = m₂ / total_mass, -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
# Next we define a function to perform the change of variables:
# Convert an ODE solution in (χ, ϕ[, p, e]) variables into Cartesian orbit coordinates.
#
# `soln` has one column per time sample and either
#   - 2 rows (χ, ϕ): p and e are then taken from `model_params = (p, M, e)`
#     (M is not needed here), or
#   - 4 rows (χ, ϕ, p, e): p and e vary along the trajectory.
# The radius is r = p / (1 + e cos χ). The result is a 2×N matrix whose rows are
# x = r cos ϕ and y = r sin ϕ.
#
# Throws `ArgumentError` if `soln` does not have 2 or 4 rows, or if it has 2 rows and
# `model_params` is missing or does not have exactly 3 entries.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    nrows in (2, 4) || throw(ArgumentError("size(soln,1) must be either 2 or 4"))
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        (model_params !== nothing && length(model_params) == 3) || throw(
            ArgumentError("model_params must have length 3 when size(soln,1) = 2"),
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
# First time derivative of the uniformly sampled series `v` (sample spacing `dt`).
# Interior points use central differences. Both endpoints use second-order one-sided
# three-point stencils, so `v` needs at least 3 samples. Returns a vector of
# length(v) entries.
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)
    forward = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    central = (v[3:n] .- v[1:(n - 2)]) / 2
    backward = 3 / 2 * v[n] - 2 * v[n - 1] + 1 / 2 * v[n - 2]
    return vcat(forward, central, backward) / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
# Second time derivative of the uniformly sampled series `v` (sample spacing `dt`).
# Interior points use the three-point stencil v[i-1] - 2v[i] + v[i+1]. Both endpoints
# use second-order one-sided four-point stencils, so `v` needs at least 4 samples.
# Returns a vector of length(v) entries.
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    forward = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    central = v[1:(n - 2)] .- 2 .* v[2:(n - 1)] .+ v[3:n]
    backward = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    return vcat(forward, central, backward) ./ dt^2
end
# Now we define a function to compute the trace-free moment tensor from the orbit
# Mass-weighted trace-free moment-tensor component I_ij - δ_ij tr(I)/3 along a planar
# orbit. `orbit` is 2×N with rows x and y (z ≡ 0). `component = (i, j)` with
# i, j ∈ 1:3; the tensor is symmetric, so (i, j) and (j, i) give the same result.
# Returns an N-vector scaled by `mass`.
#
# Throws `ArgumentError` for indices outside 1:3. Previously any unrecognized
# component silently fell through to Ixy.
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy
    i, j = minmax(component[1], component[2])
    (1 ≤ i && j ≤ 3) ||
        throw(ArgumentError("component indices must lie in 1:3, got $(component)"))
    if i == j == 1
        tmp = Ixx .- trace ./ 3
    elseif i == j == 2
        tmp = Iyy .- trace ./ 3
    elseif i == j == 3
        # Planar orbit: Izz = 0, so only the trace-removal term survives.
        tmp = -trace ./ 3
    elseif (i, j) == (1, 2)
        tmp = Ixy
    else
        # (1, 3) or (2, 3): every product with z ≡ 0 vanishes.
        tmp = zero(Ixy)
    end
    return mass .* tmp
end
# Twice the second time derivative (finite differences with spacing `dt`) of one
# mass-weighted trace-free moment-tensor component `component = (i, j)` along `orbit`.
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
# The three independent in-plane quadrupole components needed for the strain,
# returned in the order (h11, h12, h22).
function h_22_quadrupole(dt, orbit, mass=1)
    component(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end
# Strain polarizations of a single body following `orbit` (unit mass, the default of
# `h_22_quadrupole`). Returns the tuple
#   (√(π/5) · (h11 - h22), -√(π/5) · 2 h12).
#
# `T` is the type of `dt`. The numeric constants use its floating-point counterpart,
# so an integer `dt` no longer makes `T(π)` throw an `InexactError`. For
# floating-point `dt` this is `T` itself and results are unchanged.
function h_22_strain_one_body(dt::T, orbit) where {T}
    FT = float(T)
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    h₊ = h11 - h22
    hₓ = FT(2) * h12
    scaling_const = √(FT(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
# Superpose the quadrupole contributions of two bodies. Returns (h11, h12, h22), each
# being the sum of the two bodies' corresponding components.
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    return map(+, first_body, second_body)
end
# Strain polarizations from the orbits of BH1 (mass `mass1`) and BH2 (mass `mass2`).
# Returns the tuple (√(π/5) · (h11 - h22), -√(π/5) · 2 h12), where the h_ij are the
# summed two-body quadrupole components.
#
# Throws `ArgumentError` ("Masses do not sum to unity") unless mass1 + mass2 ≈ 1.
# The tolerance is 1e-12 for Float64 masses (as before). For lower-precision masses,
# e.g. Float32, it widens to a few ulps, because those cannot sum to 1 within 1e-12
# after rounding.
# `T` is the type of `dt`; its floating-point counterpart is used for the constants
# so an integer `dt` does not make `T(π)` throw an `InexactError`.
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    total = mass1 + mass2
    tol = max(1.0e-12, 4 * eps(float(typeof(total))))
    abs(total - one(total)) < tol || throw(ArgumentError("Masses do not sum to unity"))
    FT = float(T)
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = FT(2) * h12
    scaling_const = √(FT(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
# Waveform for an ODE solution `soln`, as accepted by `soln2orbit` (2 or 4 rows).
# `mass_ratio` is q = m₁/m₂ ∈ [0, 1]:
#   - q == 0 (test particle): the one-body strain of the orbit is returned.
#   - q > 0: the orbit is split into two bodies with m₂ = 1/(1+q) and m₁ = q/(1+q)
#     (so m₁ + m₂ = 1), and their combined strain is returned.
# Returns the 2-tuple produced by the strain functions; its first entry is the plus
# polarization.
#
# Throws `ArgumentError` if `mass_ratio` lies outside [0, 1]. `@assert` is meant for
# internal invariants and can be disabled, so it is not used for input validation.
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    mass_ratio ≤ 1 || throw(ArgumentError("mass_ratio must be <= 1"))
    mass_ratio ≥ 0 || throw(ArgumentError("mass_ratio must be non-negative"))
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
# Simulating the True Model
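`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background. It uses $u[1] = \chi$ and $u[2] = \phi$,
where $p$, $M$, and $e$ are constants and
$$\dot\chi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot\phi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}}.$$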
# Out-of-place ODE right-hand side for the relativistic orbit. The state is
# u = (χ, ϕ) and the constant parameters are (p, M, e). Returns the vector [χ̇, ϕ̇];
# the time `t` is not used.
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    ecosχ = e * cos(χ)
    # Factor shared by χ̇ and ϕ̇.
    shared = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root_term = sqrt((p - 2)^2 - 4 * e^2)
    dχ = shared * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root_term)
    dϕ = shared / (M * (p^(3 / 2)) * root_term)
    return [dχ, dϕ]
end
# Problem setup. Several of these globals are read inside `loss` further below, so
# they are declared `const`, as this tutorial itself recommends, to keep that
# function type-stable.
const mass_ratio = 0.0 # test particle
const u0 = Float64[π, 0.0] # initial conditions
const datasize = 250
# NOTE(review): the time span is Float32 while u0 and the model parameters are
# Float64, so `tsteps` and `dt_data` are Float32. Confirm this mixed precision is
# intended.
const tspan = (0.0f0, 6.0f4) # timespace for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
const dt_data = tsteps[2] - tsteps[1]
const dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
# Fixed-step RK4 (step `dt`, no adaptivity), saved on the data grid `tsteps`.
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Training target: the first entry of the waveform tuple. It is read inside `loss`,
# so it is declared `const` to keep that function type-stable.
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes 1 input and has two outputs; in `ODE_model` below, that input is the first state variable χ (not the time). We'll make a function `ODE_model` that takes the current state, the neural network parameters and the time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
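We deviate from the standard neural network initialization and use WeightInitializers.jl: the weights are drawn from a truncated normal distribution with a very small standard deviation (`std=1.0e-4`) and the biases start at zero, so the untrained network's output is small.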
# Network with a 1-element input and 2 outputs:
#   1. `cos` applied elementwise to the raw input (no parameters),
#   2. Dense 1 => 32 with `cos` activation,
#   3. Dense 32 => 32 with `cos` activation,
#   4. linear Dense 32 => 2.
# Every weight matrix is drawn from a truncated normal with std 1e-4 and every bias
# starts at zero. Because the hidden activations are bounded by |cos| ≤ 1, the
# untrained output is therefore small.
const nn = Chain(
Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
ps, st = Lux.setup(Random.default_rng(), nn)
# The REPL echoes the initialized parameters `ps` and states `st`:
#   - layer_1 has no parameters;
#   - layer_2/layer_3/layer_4 weights are small truncated-normal draws (~1e-4) and
#     all biases are zero;
#   - every layer's state is an empty NamedTuple.
# Similar to most deep learning frameworks, Lux defaults to using Float32. However, in
# this case we need Float64
# Float64 copy of the parameters, flattened into a ComponentArray for the optimizer.
const params = ComponentArray(f64(ps))
# Parameters are passed at call time (`nn_model(x, ps)`), hence `nothing` here.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Displaying `nn_model` prints:
#
#     StatefulLuxLayer{Val{true}()}(
#         Chain(
#             layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#             layer_2 = Dense(1 => 32, cos),      # 64 parameters
#             layer_3 = Dense(32 => 32, cos),     # 1_056 parameters
#             layer_4 = Dense(32 => 2),           # 66 parameters
#         ),
#     )         # Total: 1_186 parameters,
#               #        plus 0 states.
#
# Now we define a system of ODEs that describes the motion of a point-like particle
# with Newtonian physics. It uses
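$u[1] = \chi$ and $u[2] = \phi$, where $p$, $M$, and $e$ are constants and the neural network output $y = 1 + \mathrm{NN}(\chi)$ scales the Newtonian rates:
$$\dot\chi = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,y_1, \qquad \dot\phi = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,y_2.$$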
# Right-hand side of the neural ODE. The state is u = (χ, ϕ) and the constant
# parameters come from `ode_model_params = (p, M, e)`. Both Newtonian rates are
# multiplied by (1 + network output evaluated at χ). `t` is not used.
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # This network carries no state (every layer's entry in `st` is an empty
    # NamedTuple), so calling the stateful wrapper directly is safe. A network with
    # real state would need `st` threaded through explicitly.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
# `prob_nn` is read inside `loss`, so it is declared `const` to keep that function
# type-stable.
const prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
# Training objective. Integrate `prob_nn` with network parameters θ, using the same
# fixed-step RK4 settings and save grid as the reference data. Convert the
# trajectory to a waveform and return the mean squared error between its first
# entry and the target `waveform`.
function loss(θ)
    trajectory = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    h₊_pred, _ = compute_waveform(dt_data, trajectory, mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end
# Warmup the loss function
loss(params)  # printed 0.0007301543672369855 in the reference run
# Now let us define a callback function to store the loss over time
# Loss history, one entry appended per callback invocation.
const losses = Float64[]
# Optimization.jl callback. It records the current loss `l` and prints a progress
# line using the iteration counter of the optimizer `state`. It always returns
# `false`; in Optimization.jl a `true` return would request early termination.
function callback(state, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, l)
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.
# Minimize `loss` with BFGS (backtracking line search), using Zygote to differentiate
# through the ODE solve. The objective's second argument is unused.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback,
    maxiters=1000,
)
# retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [3.716926948976947e-5; -2.611185300336874e-5; -0.00013429067621446552; -1.696366416579309e-5; -4.390978119769373e-6; -9.617643081564518e-5; -1.6005298675731737e-5; -5.390958904166401e-5; 0.00013363547623141433; 7.739641318944038e-7; 0.00016138453793236623; -8.185639308059386e-5; -0.0001202641578856103; 5.3773019317324434e-5; 6.6060661083585e-6; -3.7492118281092473e-5; -0.0001277685951207862; 2.950943598987276e-5; -0.00017500245303370333; 0.00012975325807924824; -4.599402018347606e-6; -0.00010160609235741442; 7.208865554271469e-6; -5.9561592934139804e-5; -1.232890440404434e-6; -9.366829181093949e-5; -0.00014547424507292635; -0.0001226797758136611; -7.875890150883816e-5; 7.236666715457325e-5; -0.00018484877364235235; -2.1047693735453442e-5;;], bias = [6.030779500382455e-17, -1.942298629434973e-17, -3.034023945607868e-16, -2.43396742076359e-17, -3.698831221298329e-18, -2.1231394079210504e-16, -5.2810110168881206e-18, -6.619343694356377e-17, 1.790350231585404e-16, 5.463565258481761e-19, 1.1148465034283361e-16, -9.197488041288491e-17, -8.603453065857364e-17, -2.2193543664643123e-17, 2.9313834477053868e-18, 1.570552688013459e-17, -9.126669564252597e-17, 2.4355641303671953e-17, -1.2853837011122913e-16, 5.008327128198985e-17, -1.0775383936512136e-17, -1.4538867584862042e-16, 4.8916051947994554e-18, -5.281272120394619e-17, -1.47326883447741e-18, 4.2420458427926167e-17, -3.941431795522927e-16, 7.743560799343103e-17, -1.7425535158178745e-16, 1.571515220380806e-16, -1.0554813190689853e-16, -1.0150152795387154e-18]), layer_3 = (weight = [-7.83393854898274e-5 -9.066226017138656e-5 -0.00017831868286418554 0.00010297905713842475 -0.0001695513285620957 5.701728036103047e-5 0.00013788344835221055 2.3029605236943133e-5 7.768999670691018e-5 4.3727336325088585e-5 -0.00012917421668941087 3.175802952784062e-5 -7.908157682220641e-5 -9.64206130983166e-5 7.918094956140232e-5 8.23380894336045e-5 -9.38496025436278e-5 
0.00012329030754800233 -0.00019741494293963 -0.0002616633949027112 5.699815186845334e-5 7.398458432464248e-5 -9.806802814536528e-5 -1.5579914596755086e-5 5.669762207717828e-5 -4.256526632060136e-5 1.9523692336525155e-5 -8.016872130508663e-5 0.0001304123786586513 -0.0001381278937141245 0.00012597715862388686 -0.00021290976406716506; -4.803145813000191e-5 7.589714902702942e-6 0.00027374460595367946 -4.619124112011686e-5 -4.326207155188371e-5 -0.00012129828857084413 -2.5789656058780926e-5 -3.347929551907594e-5 7.700599132503811e-6 8.263951974210211e-6 -6.975800962479192e-5 1.2910150223458681e-5 -0.00017887788882305438 -5.965884585109318e-5 0.00010760564444992865 -2.6572165291302732e-5 7.108435355533351e-5 0.00011975222995445321 -0.00020680960616388622 8.380264015343485e-5 -0.00013002296210558908 8.628717595186425e-5 -0.00015753065118858414 8.701917366763809e-5 -0.00014644223730372088 0.0001376067678275177 1.7203525767675321e-6 -6.677422671870843e-5 0.00010416364356117599 8.304514748217799e-5 2.9322912530035508e-5 1.9469312325806313e-5; -8.732108037796292e-5 0.00010682027711888369 0.00018284140312467843 -8.431694661843056e-5 -2.4847585541949784e-5 -8.740407998206916e-6 -3.524624773993553e-5 -8.149211428238532e-5 -0.00015272648963722256 4.309495195636135e-5 -0.00012182565044489724 -2.9711381308092527e-5 -1.6733297548431546e-5 -4.996312347933173e-5 0.00016226218647213262 0.0001114665072004455 -0.00024206341843231404 5.5545788904480267e-5 2.1355974944945e-5 9.347108496009541e-5 0.00016258055327349883 -0.0002251041507638866 -1.3841162604451366e-5 5.84277265938604e-5 0.0002614666671201734 -8.680268570546777e-6 0.0001796017975488778 -7.817210209899373e-5 6.784969516453736e-5 5.129476337466814e-5 -0.00011986483079957988 -1.5583316076602127e-5; -3.6829144717091733e-5 0.00019802144301367007 -4.419753973478861e-5 -0.00015600836273505327 6.997335653296701e-5 -6.908417058865735e-5 0.00013923051223401937 -0.00010274891324796754 -9.903781106685694e-5 2.3739302507704424e-5 
-0.00012001855871630506 -2.0728716533529977e-5 7.946115433015538e-5 0.00014703656701614104 0.00014959051545021152 -0.0002028714960556385 -3.6735288501851206e-5 -5.463628336427724e-5 3.2819618127610643e-7 -3.407980771725445e-5 7.535165708988003e-5 0.00015251371873319862 1.2228691177836438e-5 9.096682255632188e-5 -6.563086442371996e-6 1.3609230466054774e-5 4.482475118721349e-5 -0.00010672355061387051 5.572665317459842e-5 6.376977320770635e-5 4.702397485971579e-5 -9.619038688241486e-5; -8.844788913501533e-5 -2.018515129402016e-5 -4.303659312078474e-5 -2.5973441148556514e-5 3.890308784667112e-5 2.562241519974091e-5 -4.413976290033458e-5 -0.00018229508946768843 -0.00020216919306782774 -4.772476546003776e-5 6.055831676698684e-5 6.010492274419076e-5 7.004503772243636e-5 6.938209884365562e-6 1.655790896966767e-5 -0.00011321219295149773 -0.00017700964468182965 -5.220407779797158e-5 4.049884813731364e-6 -1.8620352926672416e-5 -9.103898133270732e-5 -0.0001503199187538292 5.505992652961699e-5 -3.5600960978798616e-5 -7.762687396843279e-5 0.00011077861331052093 -1.7296410584242795e-5 0.00010140537546597276 1.8537187719462217e-5 -0.00014828377111933714 -5.165257112470837e-5 4.50173187744501e-5; 1.6938828663815158e-5 6.251151269386325e-5 -2.895939080482807e-5 4.344433433988695e-6 9.592986232452786e-5 -0.00013919269342588827 -8.502939538014346e-5 4.900081246593915e-5 -7.956953861432154e-5 -2.8479268551773948e-5 5.83403341514707e-6 0.00011984746295345214 -1.7610363032163517e-5 0.00014920251767468586 -4.77159309623228e-5 -4.7400896549543e-5 -0.00014934083530826597 -0.00010056309922513738 0.00020868216149828295 5.4413950307918006e-5 -0.00013067132469880634 2.713781130003246e-5 -6.249079244592464e-5 1.929745237040818e-6 -0.00017261871919184234 6.5937411607369495e-6 -0.00028429739261192987 0.0001577085486831311 -0.00017822919555621992 -9.849463178691353e-5 0.00010684296928665733 0.00010559642312434289; 3.8811896622409204e-5 1.7585539928749265e-5 -6.485910195081733e-7 
-0.00014565665366501077 1.3570330004203162e-5 -6.839066707125946e-6 -0.00014692170986347364 0.0001999406955586701 3.592590622463801e-5 -0.00013207933840714858 -6.2811017120293295e-6 -7.589498306299066e-5 -0.00015178800128307603 -4.390482324762128e-6 -1.0449664320013512e-6 2.4838421843359335e-5 9.004635615599873e-5 0.00017576721405237453 -5.25253930096549e-5 8.278790449583505e-5 2.6003309933347665e-5 5.1231945355884237e-5 -2.0970567100103907e-5 0.00012711723650898335 -6.155678453762548e-5 -5.390976399905477e-5 -4.955541794078758e-8 5.6855274699391486e-5 2.3001604524669512e-5 -1.5957110342806906e-5 -0.00014186577788461336 2.610779450367676e-5; -9.100519466657736e-5 -7.458589036461527e-5 0.00013374575999121953 -2.4801003495595508e-5 -0.00012241176384479382 -1.1952979297273689e-5 -8.24402575052959e-5 -7.222243377684193e-5 8.754008947660983e-5 6.398760367672315e-5 0.00015502217345427744 0.00010335544564084661 -1.476946065296661e-5 7.633217526388777e-5 1.1951044394736025e-5 -0.0001457183295532743 -0.00012213274542220134 -7.242984222459231e-5 1.6365626175763466e-5 4.646264121773466e-5 -5.5862948760405344e-5 9.808685197284108e-5 7.221025068800322e-5 -4.9362808316727455e-5 -2.76244893386371e-6 -5.072485666816495e-5 3.6156771126193055e-5 -8.866609798489367e-5 -0.0001122292721665135 -3.670468661213749e-5 3.7251489880953246e-5 0.0001224912650978509; -3.5479968299297166e-5 0.00011560111515955116 0.000219962554457923 -3.324651468757154e-6 5.165874090544e-5 6.044385218210523e-5 -1.6553241674229005e-5 -9.121397097096742e-5 -9.572069911719339e-5 1.615415773630019e-6 1.164041286919884e-5 3.766433137734439e-5 5.5095575864286886e-5 -2.1764908249862123e-5 -4.635184240960465e-5 -5.9476451222691235e-5 0.00012897152401073804 0.00012364441709159923 0.00012849379918570528 3.6549665584787404e-5 6.538447842267394e-5 -3.550359333367087e-5 -8.14908759804903e-5 -5.98052765617579e-6 2.0389887414966297e-5 1.1459863618982563e-5 -8.894264983446393e-5 -2.8099273275258267e-5 -0.00018198052357383987 
-7.569941744641376e-5 -0.0002280000241609768 6.43221231273846e-5; 5.7778277370780655e-5 -0.00011292821745697417 -0.0001419572365175177 -1.1055046523288014e-5 -0.00010058874631609782 -0.0001663940203779559 5.852297890854987e-5 0.00014318099402047698 0.00012164872878387675 6.586598447214797e-5 2.8867976472649463e-5 2.08687850336361e-5 -2.0588175355404724e-5 8.85814586600103e-5 -8.320545766579002e-5 -4.561744590133434e-5 -0.00014595240663580503 0.00020747084618236858 2.424993526987457e-5 -1.6351150587032495e-5 -4.932146857569008e-5 6.360419484011885e-5 0.00010898108285076902 0.00020642358849912886 -4.896714035381577e-5 8.644953759543643e-5 -5.771800733596891e-5 -0.0002135143879040322 -7.27424705455263e-5 -7.811762061279245e-5 6.595679569913289e-5 -3.329056580510933e-5; 0.00024956162383662513 -5.128192492063162e-5 7.981818203463784e-5 -4.666100425763052e-5 0.00011434201883733283 4.229723674850439e-7 -0.00010625728168176481 -0.0001163953681907249 -2.3085490051902737e-5 0.0001870314183679783 -3.397562317784479e-5 -0.00013550808648230635 3.665353802264377e-5 9.61367507034349e-5 -2.9022042094739052e-5 6.428768688695487e-5 -0.00014381309733378097 0.0001036423011601343 -8.89223854255553e-5 7.452818249933081e-5 9.887091005569334e-5 -9.421617573477533e-5 -0.00020265966008534646 0.00010919045167159227 -0.00016960371862748626 4.237282997112549e-5 5.125034590948872e-5 -3.1941076254018824e-5 4.8410023059685454e-5 0.00019627707957172667 -2.547993858113658e-5 -0.00010027818347677586; -0.00014391056076763623 -6.388535887143068e-5 -0.00011111415271787527 -2.003769343335136e-5 -5.0206271156440014e-5 -0.00020767652942939215 5.359932863595475e-5 4.660692964191109e-5 5.105436237155841e-5 -1.9856720358603156e-5 0.00011528250713692364 -1.587164017453792e-5 0.00016163058996991663 -7.701420985341278e-6 3.471119005209348e-5 3.7933493400686794e-5 0.00016797176165366003 3.813620521779679e-5 -2.5368814692984396e-6 -5.377964673587664e-5 -7.82755441139494e-5 -4.0776990248360955e-5 
0.00020332366064050751 3.9662205165033024e-6 -6.924491502814451e-6 -0.00010124541488383449 0.00015914574858126584 -8.017603879469047e-5 -2.35480337539967e-5 5.097666242019748e-5 -1.1169817229332138e-5 -3.0902446200144036e-5; 8.72211719765396e-5 6.604147780723289e-6 0.00010883842767401411 -4.419144827219789e-5 5.8851300106159314e-5 -7.150653548119757e-5 -6.223547390940727e-5 0.0001090586199793085 7.87670476497041e-5 0.00012872758523049663 8.931119807714475e-5 4.7548817532962635e-6 0.00017440261288033068 5.801944714808424e-5 7.891324346606513e-5 -1.1804837616677593e-5 -3.080388998893091e-5 -0.00012134263963484398 -1.5366499813880396e-5 -4.462452054536958e-5 -4.955480037419229e-5 -0.00011639661827171505 3.520234966022787e-5 0.00016006711373525186 8.80835694046774e-5 -7.995707368130102e-6 1.774182490083331e-5 2.8602963335687728e-5 7.060788336862041e-5 -9.554811842785883e-5 7.710543083173033e-6 6.379394906294385e-5; -0.00010474531043706614 -3.981209991303451e-5 -7.888971484185635e-5 -3.1953589114492835e-5 2.773146954444964e-5 -0.00010889672446915727 0.00023401307518673865 -2.4965068205856706e-5 6.532592780735826e-5 -1.3753180410787184e-5 0.00010981488705460784 -0.00012896943260835427 0.00018244842178268713 0.0001838495383885307 5.611729395190051e-5 5.4279267005257666e-5 -0.00014475887181156813 3.16480047508219e-5 3.234165089402396e-5 8.7181827638638e-5 -0.0001548194839238582 -1.006987147826041e-5 -7.105203874794968e-5 -5.64264617246955e-5 -2.3869949273431862e-5 7.791669597230986e-5 -5.966384813341629e-5 4.173987801500042e-5 -7.849398278318309e-5 -0.00018561683717103675 -4.087274172029272e-5 -0.00010688500950033032; -5.368163269165835e-5 0.0001836528141734757 2.60427065382089e-5 -0.00011699736509624784 0.0001229189407763639 -5.539616844769613e-5 6.452920629955913e-5 -8.364990512061817e-5 7.518652697173709e-5 6.303166142745031e-5 -0.0001565310334103583 0.00015075593970097592 6.332040780536711e-5 0.0002513180994879087 -0.0001724801507793731 6.317768262080917e-5 
7.142471864200982e-5 -2.40449693069227e-5 -0.00011008252781037595 3.311384412656384e-5 2.4154073487372078e-5 0.0001449577145263174 -8.978052514292212e-5 -0.00019580123929925224 0.00011882231470510451 -0.00020379398060184012 -7.848131221406379e-5 -0.00014244153439220167 6.55430455094762e-5 0.00010744610031681147 -0.00011143923924690517 7.093852460231714e-5; 1.9847370599263993e-5 -5.227373116697856e-5 -4.879496486440263e-5 -2.452869307065118e-5 1.4417981905963158e-5 4.842179328504168e-5 -3.199849511774633e-6 -3.249785663258984e-5 -1.0916663309519241e-5 8.832905000117963e-6 -0.00010409567334621276 7.562023961751903e-5 1.0877839262928718e-5 -4.309607833904914e-5 -2.0233587970623863e-5 0.00011101088307866146 -8.918249227038524e-5 -0.00010879536716432054 -4.731795274468262e-5 3.799516776295432e-5 -6.596161551472464e-5 -6.623139347114465e-5 -9.096351500674733e-6 0.00010277405522599107 -5.640675885418161e-5 -3.248105280848064e-5 0.00010519611237638629 6.215044118532056e-5 -0.00012986830847269666 0.0002072563627322504 8.549095651202667e-5 0.00015061616136760714; 0.00013408135197331305 0.00020528094825075088 3.992173657260687e-5 5.865525378120192e-5 1.3554394684054913e-5 -0.00010811109171974047 3.922977845157556e-5 -6.581213750670497e-5 -0.00015188594483098172 -5.629642913610626e-5 -0.00015236639086415812 7.283998053261399e-5 4.330750886091668e-5 8.340095297406593e-5 -0.00017784228333531826 -0.0001394554223565281 -0.00011588882847482558 -8.601063240113586e-5 -2.7417946000607776e-5 1.3486751925112096e-5 -3.1050256836730495e-5 -4.6773459359803214e-5 0.00013365723640398178 -2.5159694302978863e-5 -3.915972268934313e-5 2.252288646794795e-5 -9.67153977163158e-6 2.456231192135447e-5 -2.165488734809847e-5 6.267404165790678e-5 3.8087233066097334e-6 4.934493574713787e-6; -1.5592923526153995e-6 0.00018775950336675766 -0.00013678662328202366 -8.593252534665498e-5 0.0002464192631825399 -9.99558489893537e-6 -9.926961758792732e-5 -7.687971339946909e-5 -1.3482891356761983e-5 
9.206656721034684e-5 0.00011950764490138764 -0.00011625974986855183 -4.5333296677927906e-5 4.179741261141081e-5 -4.1917329165519764e-5 -8.225961466705742e-5 3.777289855606745e-5 2.68946234220467e-5 -0.00010054806882435366 4.50915114717641e-5 4.7891500050533945e-5 -2.501200081316229e-6 -0.00011418637843489981 -0.00015705300286099602 1.0095690574596089e-5 6.42189235597455e-5 8.39349351528655e-5 -1.0500105625403677e-6 -2.4580317424626716e-5 -6.517691037770625e-5 8.259575149823119e-5 -1.1554177783912536e-5; -2.9125096671187697e-5 0.00012393021420387367 0.00011915948521156646 -5.421617647704835e-5 0.0002541661884151196 8.225415472218139e-5 -0.00020853908104079012 4.171195500923951e-5 -8.885469328568457e-5 8.558960648763774e-5 -2.0972096945192683e-5 0.00015598530403615411 5.265471513659728e-5 -0.00020558590081243145 9.433964401062424e-5 1.1729565403519743e-5 -3.778344803329618e-5 -3.6727257793887924e-6 8.785671484466777e-5 -3.7937581919395306e-5 -4.2425101537510416e-5 -4.955683511987082e-5 -6.86235932853548e-7 4.372877406233586e-5 -1.926809859475739e-5 7.012234668757163e-5 3.9527188664504026e-5 4.107443088936016e-7 -3.308651180311578e-5 1.3740293570596698e-5 2.954868573388969e-5 -0.00010335839801889086; 3.146643622640682e-5 -2.48200699246762e-5 -0.00012342239595852087 -0.0001867280438858829 -1.0814851961546011e-5 4.44790425859708e-5 -4.94384071174823e-5 -2.8942492903653438e-5 -2.1746116075853493e-5 -0.00014114788017140738 -0.00012851245217857877 1.329980069896439e-5 0.00021833695838324922 0.00023304363323005297 0.00010315416194826332 1.5043744056680487e-5 1.1430450754805898e-5 -0.00019162036868182857 -6.097664616786205e-5 1.5711934352685465e-5 -0.00022811492770815302 -0.000180155322127011 -3.521798815599674e-5 -0.00011196079093168947 7.232750841977037e-5 -7.826337922268245e-5 -4.9666672097703e-5 0.00010769637493979462 -0.00013294173503198734 -0.00014545975817271717 -8.166950236440384e-5 8.923985343754319e-5; 8.923348364231248e-5 0.00014880657258953132 
-5.007385984603567e-5 -9.379267716969147e-5 2.4619263833014e-5 9.179740013830335e-5 7.602970877240633e-5 -6.94893489238167e-5 -3.676898721406997e-6 0.00012813444195824928 8.936101662734154e-5 -0.00010332443630279667 -5.339010065960613e-6 -9.423847509271592e-5 8.401823729553508e-5 -8.970533522034926e-5 3.0225914580161036e-7 8.997854685442513e-6 -1.1831032814698228e-5 5.324068920495826e-5 3.1328448600870902e-6 -2.562789736505768e-5 0.00020895046768754034 -2.1980898189504895e-5 6.091485270343043e-5 -0.0001248389918585677 0.0002765151265081418 1.024328125973839e-5 9.807073079325627e-5 0.00013677745014472432 0.00011619455185758377 -4.840340184120267e-5; -9.214544477062103e-5 2.903935663299465e-6 0.00017371112701594567 2.1591486100179545e-5 5.375596449822422e-5 -4.378939044899328e-5 3.117088274676581e-5 -2.28997032280763e-5 0.00012206023896224542 3.351442867652748e-5 -0.0001492224000429783 -4.017281387122298e-5 6.471758754569876e-5 6.843044323260121e-5 -4.9324102266735855e-5 -0.00016036912345215485 7.747108979306204e-8 5.078593315202954e-5 -0.00011480180907883357 6.690172088211175e-5 4.0741480941607166e-5 5.735343804927493e-5 -4.2879437361888114e-5 6.764826323716109e-5 -2.320663768104043e-5 -0.00010077681870087985 9.198653063149776e-5 7.939030549833374e-5 -2.97743826848914e-5 1.1664926915816955e-6 1.4900625527026081e-5 2.8566815697660093e-5; -0.00020909288998084295 6.674106423886247e-5 -4.399965093325673e-5 6.798835254883099e-5 6.629285069791266e-5 -0.0001195796429048764 4.114951678345609e-5 -1.006946678828404e-5 -0.0001149487796457844 -6.8547451354995e-5 -2.1756289988106288e-5 -5.033169129020927e-5 3.9344308043650004e-5 5.303173036807258e-5 -0.0001238025432205419 4.1640378349962066e-5 -5.500910976407189e-5 -3.8135429185081625e-5 3.624104629868383e-6 -9.060133726889271e-5 -1.3817561756901963e-5 -0.00018538763089892375 -5.574894948246443e-6 -5.123595037022496e-6 3.35839032966517e-5 -6.1148464912677425e-6 -6.46914921337124e-5 9.594828207623163e-6 -0.00011151594646363978 
-6.613539864631022e-5 -2.7484089234147214e-6 3.2711672411800386e-5; 0.00011172529272757957 1.1040101044864524e-5 -7.514618985563623e-5 -0.00016608720537019495 -6.120879146238123e-7 4.6957512013848553e-5 -0.0001330044301806537 6.487897254057658e-6 4.8556913016604914e-5 1.8796071987252974e-6 0.00011118901099568735 6.842554258269553e-6 -8.393883538641415e-5 -0.00019285040391818953 0.00014317201791586736 -2.4514485202734144e-5 -0.00012278530405577512 6.51468278251802e-6 2.656548028487634e-5 -0.00012329878293652286 -1.1320408274832558e-5 3.6266649177067484e-5 9.008931599114473e-5 0.00012426125110382862 -1.7053803757653496e-5 -4.592456892109358e-5 -0.0001200176680225307 -3.79081298245523e-5 3.4354371082968224e-5 -0.00023068189105228675 -8.559722985348473e-5 0.00010509811770193692; -2.497807667693924e-5 -8.724126723407291e-5 -8.142631463283011e-5 -0.00010114077644700707 9.926779262443436e-5 -0.00024058416215280915 0.00010434502860718809 4.6636709388829036e-5 -4.73550969634235e-5 0.0002562142365188467 -5.92442372596712e-5 -0.0004306762382345961 -9.497362384149756e-5 -7.449420422360376e-5 8.238225977240536e-6 -5.0166443312059676e-5 2.784284356880836e-5 -0.00012967895213094032 -0.00014451772926415288 6.514863003235767e-5 -0.00011963373594496966 -7.253355920125462e-5 -2.701459720426803e-5 1.5080249117298887e-5 8.110560543557058e-5 -6.648037719584067e-5 -4.861899628462002e-5 0.0002299554219046064 -0.00023616739394653099 3.525604211448088e-5 4.350634504070214e-5 -9.785485940099258e-5; 6.801787864569236e-5 -4.473814218036181e-5 6.65319971336256e-5 0.0001376409204272901 8.079866772820282e-6 8.730679152685173e-5 7.313809371007359e-5 7.147695620696528e-5 7.891349420568169e-5 8.791732823371475e-6 -5.4453197237695566e-5 -0.00034133208058156845 0.00010919140228687321 2.0444089812003123e-5 -0.00013518020526353497 0.00011049158680870789 6.233864921871259e-6 -2.5144728654053248e-5 -8.859989569139572e-5 -0.00011373472764936931 9.97264364446863e-5 0.00017154215353692753 
-5.4585848858917585e-5 0.00022792776914465658 0.00016238926079276722 -0.00018213131946359762 9.273776274092423e-5 0.00017764217088150284 -3.284698280586657e-5 -1.4232891346553321e-5 -0.00016451548393200347 5.907312194503581e-5; 6.330440931872367e-5 -0.0001291423412699913 -0.00013457979536230926 -2.04380447617694e-5 5.846157922444782e-5 0.00010642487007810563 -7.276765122307389e-6 0.00011298121365637405 0.00010488005330937663 -3.5865834310763416e-5 -1.782422974845257e-5 2.479528925894111e-6 5.807479659363408e-5 1.1385467710559086e-5 -8.888458855446332e-6 -2.0115361498548686e-5 -7.755089051404274e-5 -7.084459672530276e-5 0.00013143209629145076 -0.0001562788400993427 0.00016514806896832687 -7.237288979429002e-5 4.133121013676664e-5 -4.9497515523919144e-5 -2.4725597417793134e-5 6.311863956891766e-5 6.306615081068775e-5 5.6674782295364494e-5 -4.620066635158986e-5 -0.00015942388732998763 -8.848613093921389e-5 0.0001497918984891186; 4.698436397361899e-6 9.093626654363662e-5 4.106608180450024e-5 0.0002054597957351384 0.00010288034759584369 0.00012117844239697422 8.570887662359977e-5 -1.671676140971264e-5 0.00016332427408201385 0.0001010539512674466 -3.460684961170018e-5 8.409573139269902e-6 -4.2923083627781954e-5 1.5579128058801733e-5 0.0002199033428507998 -9.1290732187858e-5 0.0001444212197850449 -2.2238667542054917e-5 -5.129047490187473e-5 -0.00019899863905152446 -6.0099971550456054e-5 1.9842149823722406e-5 -0.00015340831810020363 7.316742323957294e-7 0.00010483838782535422 -5.4353805852535634e-5 3.5233892459696354e-5 1.5456748270709326e-5 -6.60067649266598e-5 -0.00011730672626866959 0.00019940352144770445 -9.775384164525681e-5; -0.00015527291616383783 0.00011435343112604778 2.5441440083516364e-5 -4.677854850187011e-6 -0.00016997892162255935 -3.278877993654355e-5 -5.528002366266174e-5 -0.00016939463312230668 -4.861128109115049e-5 -4.803593474282873e-5 -5.7506626763640806e-5 -3.2269689501950914e-7 -4.2441589176461724e-5 -8.425511696334599e-5 -0.00012909502475015808 
1.796035139629783e-5 0.00013341739322567315 3.20240062130067e-5 -9.659176431152447e-6 5.479644387304797e-5 -7.168739903211579e-5 7.982519236901068e-5 -6.175496018916351e-5 3.799848418491366e-5 -5.7139209094047445e-5 4.477924630174462e-6 -5.959800255441544e-5 -6.669785289051196e-5 -9.621221470372055e-5 0.00012134668868309476 -3.3907956822793196e-5 2.253251249541074e-6; 0.0001222431450962115 0.00012962442947281785 5.7501324999241736e-5 0.00014206510918410759 1.2769564317926104e-5 0.00019203581855362537 0.00019282066154955219 8.271345673091966e-5 -6.254312079436834e-5 -3.848314058887512e-5 -4.1292424203499656e-5 0.0001980746596465844 6.43839297928872e-5 9.530602205740545e-5 -1.5297095661305086e-5 1.2892668973284565e-5 2.529408034133038e-5 -4.1351741447949056e-5 -5.0077604601762734e-5 -0.00015329947978062308 0.00012223383187046538 -0.00019659742297408847 -0.00016143915177127083 1.4296735070450468e-5 3.504511322892775e-6 -0.00012797182766990483 -0.00011364533616034075 -0.00013620602890187624 -2.682219978924587e-5 -9.718562118978915e-5 -0.00012476556053189327 -7.391152080418282e-5; -5.966944082736916e-5 -1.914307905515995e-5 8.963742219511562e-5 0.00018812666893374346 -6.71983762118387e-5 9.944880365529246e-5 0.00017215421029269867 0.00010243624636806437 -0.0001041779069169967 -0.000114522499654943 -4.639976407464228e-6 0.00015798322054925797 0.00010485420715431274 -1.2007444753054643e-6 -0.00016408480224761916 -0.0001457095895717288 -0.00014948979879827844 1.1169518063448971e-5 4.5847082208075704e-5 0.00010816398934129894 7.154693229260119e-5 9.841704911564336e-6 6.81014025644046e-5 -0.0001163874876625295 9.560076067593147e-5 6.160111296358716e-5 -9.035929008503013e-6 0.00010461585405882908 -2.6490271932370907e-5 0.00017313193892497212 1.7607309916799993e-5 8.791074286583963e-5; 3.382021397298827e-5 3.775497909116222e-5 4.879752170125442e-5 5.739946438295412e-5 -9.567725417839728e-5 8.980291155937469e-7 7.612772835015895e-5 -6.737276770111412e-5 0.0001269017980277393 
4.9056596724014784e-5 2.0107289417123655e-5 -7.506154132245182e-5 -2.0635308358415186e-5 -6.934416658880428e-5 -1.9527300067081965e-5 0.00013480776625981246 5.450396794656637e-5 5.5559451099882344e-5 1.6576778322984834e-5 -0.00018654843505175417 -2.4320933773949677e-5 -0.0001519112388294368 3.163314114969791e-5 -3.550504469133194e-5 -3.940405211581284e-5 -3.6516817529285155e-5 9.563841119581803e-6 -0.00013488468310499278 -0.0001163439860491278 6.82851264329427e-5 -7.291082645934815e-5 2.2311189851939175e-6], bias = [-3.1661865188974626e-9, -4.935945739299958e-10, 6.773684321081742e-10, 7.559197119634583e-10, -4.105964939241635e-9, -2.0090039750548913e-9, -1.1674967125731643e-10, -7.036978063879465e-10, 1.5488858386446333e-10, -2.0390685661877108e-10, 6.432447024754245e-10, 2.2653350188475527e-10, 3.3221808345224987e-9, -1.1921479759737347e-9, 4.4079344806884005e-10, 3.8995010115295434e-10, -1.2464055693543441e-9, -4.245728660276272e-10, 1.593025381082978e-9, -4.459693078344771e-9, 4.028017141592466e-9, 8.403986715881173e-10, -4.0016122754593955e-9, -2.340611308737713e-9, -5.351235611500193e-9, 2.4742508656885748e-9, 2.893424863988426e-10, 2.8471288413057456e-9, -3.3694967253978302e-9, 2.860287464285405e-11, 3.263255546000445e-9, -1.336486846148743e-9]), layer_4 = (weight = [-0.000593053095930714 -0.0007946594169579043 -0.0006789138046269256 -0.0006736321597273989 -0.0006281499770533718 -0.0006467891417281563 -0.0008974849572661976 -0.000701503036152726 -0.0007072422866258231 -0.0007898494528730875 -0.0007271119206423325 -0.0006914105701331661 -0.0006074057060786563 -0.0005876197139559633 -0.0007210315600420105 -0.000867623769458736 -0.0006889821052483196 -0.0006304635523473107 -0.0006118547800492084 -0.0006695799333180675 -0.0007081599723512993 -0.0007078194346692411 -0.0006694896657174956 -0.0007046669213662709 -0.0005106648533400487 -0.0005813648655892891 -0.0006948434963697854 -0.0006385629122776955 -0.0008379525529626718 -0.0006568040321971684 
-0.0007973301417994263 -0.0007651463406080634; 0.00013040215248582174 0.0001576685174348691 0.00028354157866503126 -6.69123637320766e-5 0.00029299038470762657 0.00030995034862888696 0.00018188146896295563 0.0003917712659638256 0.00034465959645637597 0.00035699766320800084 0.0001407266309207205 0.0004452091613707507 0.0002736753432331936 0.00025617155178832485 0.00022732622588095267 0.0001586955121837759 0.0003027267798654679 0.0002224942561958167 4.366473163511302e-5 0.00015791122210165318 0.0002075102269386806 0.000303752084435396 0.00022179113359256398 0.0002225390565672504 -3.1283714932263345e-5 0.0004111654974981046 0.00021307915584580859 0.0002161182229719744 0.00034972075114502967 0.0004429422498137358 0.00011839956395222277 0.00017044528487534442], bias = [-0.0007040653823325861, 0.00023761324165170615]))Visualizing the Results
Let us now plot the loss over time
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    # `losses` holds one entry per callback invocation during training.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; marker=:circle, strokewidth=2, markersize=12)
    fig
end
# Finally let us visualize the results
# Re-simulate with the optimized parameters `res.u` and compute the resulting waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference waveform.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12, strokewidth=2
    )

    # Prediction before training, kept for comparison.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; alpha=0.5, marker=:circle, markersize=12, strokewidth=2
    )

    # Prediction with the trained parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        alpha=0.5,
        marker=:circle,
        markersize=12,
        strokewidth=2,
    )

    # Pair each line with its markers so each legend entry shows both.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
# Appendix
using InteractiveUtils
# Print Julia version, platform, and environment details for reproducibility.
InteractiveUtils.versioninfo()
# GPU toolchain details are printed only when MLDataDevices is loaded AND the backend
# package (CUDA / AMDGPU) is loaded AND MLDataDevices reports that backend functional.
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
# Blank line separates the GPU report from the preceding output.
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
endJulia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, icelake-server)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0This page was generated using Literate.jl.