Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakieDefine some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, they aren't very relevant.
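We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$, and use Newtonian formulas to obtain the individual positions $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$ about the centre of mass, where $M = m_1 + m_2$.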
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) path into the paths of the two bodies about their centre of
mass.

`path` is the separation `r = r₁ - r₂`. With `M = m₁ + m₂`, return
`(r₁, r₂) = (m₂ / M * path, -m₁ / M * path)`, so that `r₁ - r₂ == path` and
`m₁ r₁ + m₂ r₂ == 0`.
"""
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in orbital-phase variables into a Cartesian planar orbit.

`soln` has one column per time sample and either
- 2 rows `(χ, ϕ)`: then `model_params = (p, M, e)` must supply constant `p` and `e`
  (`M` is not used here), or
- 4 rows `(χ, ϕ, p, e)`: then `p` and `e` vary per sample.

The radius follows `r = p / (1 + e cos χ)` and the result is the `2 × N` matrix
`[x'; y']` with `x = r cos ϕ` and `y = r sin ϕ`.
"""
function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = view(soln, 1, :)
    ϕ = view(soln, 2, :)
    if size(soln, 1) == 2
        # Test for `nothing` first so a missing `model_params` reports this message
        # instead of failing with a MethodError from `length(nothing)`.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = view(soln, 3, :)
        e = view(soln, 4, :)
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled series `v` (spacing `dt`).

Interior points use the central difference `(v[i+1] - v[i-1]) / (2dt)`. The endpoints use
the second-order one-sided stencils `(-3/2 v₁ + 2 v₂ - 1/2 v₃) / dt` and its mirror image,
so the result has the same length as `v`. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled series `v` (spacing `dt`).

Interior points use the central stencil `(v[i-1] - 2v[i] + v[i+1]) / dt²`. The endpoints use
the second-order one-sided stencils `(2v₁ - 5v₂ + 4v₃ - v₄) / dt²` and its mirror image, so
the result has the same length as `v`. Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor of a planar orbit, scaled by `mass`.

`orbit` supplies `x = orbit[1, :]` and `y = orbit[2, :]`. The component `(1, 1)` gives
`x² - (x² + y²)/3`, `(2, 2)` gives `y² - (x² + y²)/3`, and every other index pair gives the
off-diagonal product `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    if component[1] == 1 && component[2] == 1
        diagonal = xs .^ 2
    elseif component[1] == 2 && component[2] == 2
        diagonal = ys .^ 2
    else
        # Off-diagonal entry: removing the trace only affects the diagonal.
        return mass .* (xs .* ys)
    end
    trace = xs .^ 2 .+ ys .^ 2
    return mass .* (diagonal .- trace ./ 3)
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (sample spacing `dt`) of one trace-free moment-tensor
component of `orbit`, as produced by `orbit2tensor` and `d2_dt2`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    ddot = d2_dt2(orbit2tensor(orbit, component, mass), dt)
    return 2 * ddot
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole components for a single orbit, each
computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    index_pairs = ((1, 1), (1, 2), (2, 2))
    return map(ij -> h_22_quadrupole_components(dt, orbit, ij, mass), index_pairs)
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return `(h₊, hₓ)` for a single unit-mass body: `h₊ = √(π/5) (h11 - h22)` and
`hₓ = -√(π/5) · 2 h12`, with the constant evaluated in the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    q11, q12, q22 = h_22_quadrupole(dt, orbit)
    amp = √(T(π) / 5)
    plus = amp * (q11 - q22)
    cross = -amp * (T(2) * q12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return `(h11, h12, h22)` for a two-body system: each body's components from
`h_22_quadrupole`, summed component by component.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    return map(+, first_body, second_body)
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the `(h₊, hₓ)` strain of the (2,2) mode for two bodies with orbits `orbit1`,
`orbit2` and masses `mass1`, `mass2`. The masses must sum to one (within 1e-12).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    q11, q12, q22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    amp = √(T(π) / 5)
    plus = amp * (q11 - q22)
    cross = -amp * (T(2) * q12)
    return plus, cross
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Return the `(h₊, hₓ)` strain of the (2,2) mode for the ODE solution `soln`.

`soln` is converted to a Cartesian orbit with `soln2orbit` (using `model_params` when
`soln` has two rows). `dt` is the sampling interval of `soln`. `mass_ratio` must lie in
`[0, 1]`:
- If it is zero, the test-particle (one-body) formula is used.
- Otherwise the relative orbit is split with `one2two`, using masses
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, which sum to one.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Normalised masses (m₁ + m₂ == 1), as `h_22_strain_two_body` asserts.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        # Test-particle limit: the orbit is used directly with the default unit mass.
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
# Simulating the True Model
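`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses $u[1] = \chi$ and $u[2] = \phi$,
where $p$, $M$, and $e$ are constants, and evolves
$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}}.$$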
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side `[χ̇, ϕ̇]` of the relativistic orbit ODE for the state `u = (χ, ϕ)` and
constants `(p, M, e)`. The rates depend only on `χ`; `t` is not used.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)
    # Factors shared by both rates.
    common = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)
    dχ = common * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * root)
    dϕ = common / (M * (p^(3 / 2)) * root)
    return [dχ, dϕ]
end
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions
datasize = 250
# The time axis is kept in Float64 as well. A Float32 `tspan` would make `tsteps` and
# `dt_data` (and hence the type `T` inside `compute_waveform`) single precision, and the
# integrator's time variable would follow the Float32 `tspan`.
tspan = (0.0, 6.0e4) # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1]
dt = 100.0 # fixed integration step passed to `solve`
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic model with fixed-step RK4 and sample it at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns (h₊, hₓ); only the plus polarization is used as data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the orbital phase variable $\chi$, i.e. `u[1]`) and has two outputs. We'll write a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
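We deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal distribution with standard deviation $10^{-4}$, and biases start at zero.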
# Four-layer network: `cos` is applied elementwise to the scalar input, followed by two
# hidden `Dense` layers with `cos` activations and a linear 2-output layer. Every `Dense`
# layer shares the same initializers: truncated-normal weights with std 1e-4, zero biases.
const nn = let
    layer_init = (; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; layer_init...),
        Dense(32 => 32, cos; layer_init...),
        Dense(32 => 2; layer_init...),
    )
end
# Use an explicitly seeded RNG: the unseeded `Random.default_rng()` produced different
# initial parameters (and therefore different downstream results) on every run.
ps, st = Lux.setup(Random.Xoshiro(0), nn)
# Similar to most deep learning frameworks, Lux defaults to using Float32. However, in
# this case we need Float64.
# `f64` converts the Float32 parameters to Float64; `ComponentArray` flattens the nested
# parameter NamedTuple into a single vector that the optimizer can update.
const params = ComponentArray(f64(ps))
# Parameters are left as `nothing` here; they are supplied explicitly on every call as
# `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Displayed as:
# StatefulLuxLayer{Val{true}()}(
#     Chain(
#         layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#         layer_2 = Dense(1 => 32, cos),      # 64 parameters
#         layer_3 = Dense(32 => 32, cos),     # 1_056 parameters
#         layer_4 = Dense(32 => 2),           # 66 parameters
#     ),
# )         # Total: 1_186 parameters,
#           #        plus 0 states.
# Now we define a system of ODEs that describes the motion of a point-like particle with
# Newtonian physics. It uses
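$u[1] = \chi$ and $u[2] = \phi$, where $p$, $M$, and $e$ are constants and the two network outputs $\mathcal{N}_1(\chi)$ and $\mathcal{N}_2(\chi)$ multiplicatively correct the Newtonian rates:
$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left[1 + \mathcal{N}_1(\chi)\right], \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left[1 + \mathcal{N}_2(\chi)\right].$$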
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = (χ, ϕ)`.

Both rates share the Newtonian factor `(1 + e cos χ)^2 / (M p^(3/2))`, with `p, M, e` taken
from `ode_model_params`. The factor is multiplied by `1 .+ nn_model([χ], nn_params)`, so a
network that outputs zeros reproduces the Newtonian model. `t` is not used.
"""
function ODE_model(u, nn_params, t)
    χ = first(u)
    p, M, e = ode_model_params
    # In this example `st` is an empty NamedTuple, so we can safely ignore it; in general,
    # `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    χ̇ = newtonian_rate * correction[1]
    ϕ̇ = newtonian_rate * correction[2]
    return [χ̇, ϕ̇]
end
# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
# Simulate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings as the reference model so the two waveforms are directly comparable.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the plus-polarization waveform of the neural ODE solved with
parameters `θ` and the reference `waveform` from the relativistic model.
"""
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
# Warmup the loss function
loss(params)  # the rendered run returned 0.000732273139810515
# Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(θ, l)

Optimization callback: append the current loss `l` to `losses` and print a progress line.
Only `θ.iter` is read from the optimizer state `θ`. Returns `false` so the optimization is
not halted.
"""
function callback(θ, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.
# Zygote differentiates `loss`, including the `solve` call inside it.
adtype = Optimization.AutoZygote()
# The objective ignores Optimization.jl's hyperparameter slot; only the network
# parameters are optimized.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
# Start from the untrained network parameters.
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [4.010711563741982e-5; 0.0001183957720056412; 0.00011214615369681743; 1.1913782145695781e-5; 1.2286018318262268e-6; 0.0001396372244925763; -0.00015999372408251577; -2.5821012968645955e-5; -0.00018654538143856296; 9.946303907785512e-5; -3.060846211152681e-5; -6.421946454792976e-5; 0.00013065664097656224; 5.921201591254605e-5; 7.70223268773206e-5; 0.00011927293235190367; 5.2953433623716536e-5; 5.805523687738654e-5; -8.409263682542029e-5; 4.963309402224395e-5; 7.538509817088257e-5; -8.516334492007809e-5; -4.829355475518759e-6; -9.073265391629208e-5; -0.00010341135202896384; -0.00013027750537720535; -0.00013695072266288645; -0.00012043255992466692; -4.597263796306228e-6; 9.326930012312307e-5; -0.00010517659393367593; -0.00010931184078789432;;], bias = [1.0019519172263451e-16, 2.82102861014968e-16, -5.126967814817518e-18, 1.7548715624101302e-17, 1.6694117385068154e-19, 4.665668997915915e-17, -8.897058240322315e-17, -1.870510470795611e-17, -3.914612719381545e-16, -4.144608555460758e-17, -3.493124775118815e-17, -3.1119327691584313e-18, 1.1653996520096843e-16, 9.782468821272434e-17, 2.1941794180210976e-17, 1.4190510522806725e-16, 5.85944813366036e-17, 2.338070279292716e-17, -1.0542856433057057e-16, 4.2890243576245366e-17, 7.939707734634338e-17, -7.445414472227158e-17, -3.45504039954018e-18, -6.064501952986029e-17, -1.1729034260967902e-16, -1.147216634535022e-16, 1.7003612158811977e-16, 2.1466787332354507e-18, -7.680598119137707e-18, 9.741482368096388e-17, -4.921612144988719e-17, -1.1032419210442921e-16]), layer_3 = (weight = [-1.0351067717734552e-5 2.430312877954681e-5 0.00014694507870849504 -5.318794257213572e-5 -5.673351671752725e-5 -0.00017741590392284057 8.386104189816686e-5 -8.132558406536698e-5 7.637576953533536e-5 -6.055703971974175e-5 -8.977974477197282e-5 -8.40578298407735e-5 2.4475283393665534e-5 -0.00011929743169780219 -0.0001604446163724311 -8.506754358272776e-5 0.0001119270443127367 
7.880935907856237e-5 0.00010191248901214928 5.2564361507805855e-5 0.00016695086258955826 3.335071113219955e-5 5.0938865283073666e-5 0.00012408687206011702 -1.1803397970242663e-5 1.701076024296356e-5 0.00017814629125900632 0.00014042758146792867 -7.478774324731177e-5 -0.00010176434563213539 -6.826531290748752e-5 9.648255098404667e-5; 6.728561410844867e-5 -3.243250912938703e-5 -8.529825291199108e-5 3.436222968438255e-5 0.0001324942232661563 -4.208280452595549e-5 0.00013321870491771485 -3.0789437057244668e-6 5.5979667280814155e-5 2.9669115150943847e-5 -7.071005975895949e-5 2.0950156880906507e-5 -5.580632651409582e-5 7.390829986824258e-5 -3.373868176432733e-5 -5.798231455366768e-6 -0.0001385439581151439 2.102973221034346e-5 -4.631958373075762e-5 -7.252584774115518e-5 -9.090252120324275e-6 9.260881333310677e-5 -0.00010995052153312986 0.0002569992458068445 -0.00029260677266003057 -9.289371237692868e-5 0.00010400294112904474 6.25102357792345e-5 -0.0001353611195606497 0.00017182449946111589 3.855814347324467e-5 -4.048731798840184e-5; -8.05470035165571e-7 0.00026196319341169794 0.0001804831201347922 9.232826524158747e-5 -5.731042510386733e-5 0.00015495959924509244 0.00010259323717535212 -5.466005931747398e-5 1.211154840774991e-6 -0.00012962114742049374 -2.7185957873815698e-5 0.00011773081605998158 0.0001373504721464007 -7.216614066646074e-5 0.00010289180609607224 0.00013627793778641238 -5.711081298151081e-6 -1.4449874874952608e-6 -2.0189684432349567e-5 -0.00012078991748648219 -6.206039578909869e-5 -0.00012171768939378731 -0.0001301785585333052 -0.0001462633870637527 0.00018335124441864456 8.078821083972009e-5 9.548808266616937e-5 -0.00015669549501990844 0.00010264033444901748 4.022759931604155e-5 0.0001801513219156719 1.687165350221906e-5; 1.7372227262507962e-5 -0.00013270217470972027 2.09522949039635e-5 -7.88442267771697e-5 6.735495105703557e-5 4.269552278404105e-6 -8.440311660205627e-5 -0.00010169557219751885 -7.205706071953524e-5 -0.000134356392777137 
-8.76417180957156e-5 -1.2036698648902643e-5 1.1012451864645963e-5 1.646589032634037e-5 -9.28481870187471e-6 7.466407976198794e-5 8.11499047590544e-5 -7.624964041046465e-6 1.8046430227966276e-5 0.0001433962460173417 -2.8427171968818052e-5 -2.9370090600880495e-5 -9.178942866538455e-5 1.4231626551461643e-5 -0.00010271398070881436 -0.00014576344994975295 2.4105858078678383e-6 0.00014196959080027275 0.00016290831251853345 -5.112839630330252e-6 -8.426677970828768e-5 0.00011323196656621753; -0.00012384426834807083 -8.778717627327776e-5 0.00014065068800655977 -8.261511273496406e-5 -9.914184655109998e-5 3.019129660897966e-6 6.339912480865039e-5 -0.00012733521460369683 -0.0001683294295453821 -2.821949154358529e-5 -9.552693980560115e-5 -5.964055624836882e-5 7.618031684538309e-6 9.705561676936362e-6 1.3248973953426256e-5 -0.00010988277525219915 -1.9268202851992778e-5 7.88026617718864e-5 -6.895621944031897e-5 -0.0001973268855018554 -9.439360754377682e-5 -6.874297567456215e-5 -1.8216177597531272e-5 -3.890507749211052e-5 6.0275570756769485e-5 2.0944088571846497e-5 0.0001673521282263384 7.617397470122406e-5 -0.0001817231723927019 -6.970890997954567e-5 -0.00017091435900701637 -0.00016138605595378465; 0.0002017801820719353 8.448681223202223e-5 -4.9562360397578396e-5 -0.00015506179584658358 6.1734885483504165e-6 5.514204571484636e-5 9.120136785733696e-6 7.050648168192534e-5 -9.247963276465222e-5 1.571961042176655e-5 7.9430443695721e-5 7.161364960015654e-5 -0.00015491677145943657 6.939754988242913e-7 -9.971313516612466e-5 -0.000146513578835915 6.722857546524195e-5 0.00013134377554056218 7.414464966390786e-5 -7.869656070138243e-5 0.00012060882767659338 0.00011777412186606338 8.438666595143626e-5 4.9714399613406526e-5 0.00018738388502501053 -3.6722226975296134e-5 -2.6959441636035777e-5 0.00022907897841181282 1.8455537831735665e-5 -4.7179306017967446e-5 -0.0001284088280564541 0.0002346495679765563; -0.00015953544896734527 6.770717920330327e-5 8.435029923743478e-5 5.904610117608032e-5 
-0.00010631018192913108 -0.00014576523352120778 -9.019670795355426e-5 4.8906101090268496e-5 -4.23204120871054e-5 0.00013492476756977643 4.9690834946831137e-5 -6.0622850634953985e-5 4.981432249946634e-5 1.7068036702712673e-5 -0.00010078725716271828 6.34206506383138e-5 0.0002286067672848183 2.6715567235387476e-5 -4.7481201529133715e-5 2.258545720718722e-5 -3.8117600690437754e-5 9.572754681170589e-5 -0.0001365852451108927 -6.654734277804667e-5 -7.868836945095274e-5 -0.00012492208878216223 -6.441130352145803e-5 7.679352909640857e-6 -7.528090025704077e-5 0.00011249061828306612 -4.668603758732945e-5 -3.55304302338988e-5; -0.00011609598831288287 2.910249432883517e-5 8.362621057298317e-5 4.578062918734371e-6 -2.71291456448684e-5 8.877029805429601e-5 -0.00010609925183249284 7.865609123822525e-5 0.00014670993451713056 -2.0174070494395776e-5 -7.087637012606262e-5 -0.00011347054631929605 9.061636129612411e-5 -0.00014413797279136018 -6.647648941740144e-5 -0.00019426732714069713 3.8161094300302216e-5 -1.4612664097659256e-5 7.052686024675401e-5 -4.720474487986661e-6 -6.12745871768867e-5 -8.162773982636374e-5 -1.1559222082870686e-5 -9.335161402137688e-5 7.006218121372439e-5 -4.101732054696607e-6 1.6319973757926744e-5 -8.558331581444332e-6 0.00022627922490922953 0.00011463805177467208 -9.116359331674904e-6 4.816097399438082e-5; 4.0576045123793176e-5 2.5717073570179308e-5 1.9367591201280284e-6 -9.544382535047367e-6 -0.00013068262576528182 0.00014109853275184544 -1.3211232341935887e-5 4.297209800167694e-5 -0.00022586855136724114 -5.0554174392545886e-5 -1.4898842507342305e-5 0.00010003286790631686 -4.2017180546476865e-5 -6.698354498187143e-5 -2.237478709250599e-5 -0.00012712674886389944 4.4764126344284814e-5 -0.00010154884019104836 -3.498639723804616e-5 2.729833379160016e-5 0.00016522096613951103 -9.225327246187373e-6 -8.020378534386221e-5 7.889641476392984e-5 -0.00014788803907978887 4.8739422184357924e-5 0.00014342328852106793 -1.3437446411641453e-5 -1.1809245349674567e-5 
-0.00015116485390282897 1.7016108743319743e-5 -0.00011799107831140507; 2.0483004906575987e-5 -0.00011882989923235568 0.00015670243016445459 6.643574312563306e-5 0.0001814566931594281 -0.00012163700005969513 0.00024960236629938283 -6.049745928378519e-5 5.131138829399507e-5 -0.00017548080353065046 1.5837278685078748e-5 -7.659275427925109e-5 7.782314066078008e-5 -8.257195434764673e-5 -1.520161361752466e-5 -1.0682245313933352e-5 3.492348702509235e-5 -0.0001369041564734531 -1.39803168471428e-5 -0.000153781016645941 -0.00017570944322271896 0.0001682985167158893 -3.412349900229584e-5 5.08373560174495e-5 6.395862880397014e-5 -0.0001655324885481833 -1.3278595300493996e-5 5.172424431891579e-5 5.87513860105655e-5 -8.077788091601207e-6 -5.483249327282988e-5 9.029435029021837e-5; 5.435474800524489e-5 -7.161123589266998e-5 0.0001948608819080614 2.8624589919310276e-5 0.00012464684319329664 -2.1473168584090196e-5 -9.136938255633355e-5 -0.00010118430013189823 -6.310593975595445e-5 6.923494716520501e-5 -0.00025097051051798085 5.3269004144204464e-5 -5.0441462502386575e-5 -4.192238933992409e-5 2.289708159569317e-5 -3.045078714731127e-5 4.803470934039155e-5 0.00012077101333179675 0.00011437118287776714 -0.0001380773488048353 -0.0001326677616301817 4.526629114170673e-5 3.488481293661729e-5 0.00011643476814842885 2.3738196847807374e-5 -1.0092106589037131e-5 5.344112783550247e-5 -7.137718288813699e-5 0.000192423712593701 5.162443763456403e-5 -0.000293610678039295 0.00012388231467102707; 3.109844047238816e-5 -2.6492792217060222e-5 1.3155976289534898e-5 6.104344634050669e-6 7.294351317349608e-5 -7.8185315119219e-5 0.00013230706909333743 1.4627626028496404e-5 0.0001072598469409839 -7.368588475847957e-5 -0.00010562293673038709 6.778018078429356e-5 0.00012494748162233767 -3.442385340796112e-5 4.6396472412279686e-5 -0.00018846548400225515 -0.0001698572369558948 -3.972306417299221e-6 1.490024706535201e-5 0.0002541926159256858 6.740874055221824e-6 2.8893552466879315e-5 -4.037039897858381e-5 
-5.311974663974107e-5 -5.4393177465418147e-5 3.28671930265849e-5 -0.00018056596226862157 -4.0569440075079273e-5 9.49311952436582e-7 5.344554747815847e-5 0.00019735781179869188 -5.4850918872863017e-5; 1.7061080099678367e-5 -0.0001194998877035579 -1.2673507670104194e-5 -4.2491330575999805e-5 7.221594798479024e-6 1.0023581082360466e-5 -4.818532765783877e-5 0.00016170915837575577 6.32717673870576e-5 7.484067096085036e-5 2.4926870493813298e-5 0.00024793341795158424 5.2815976188877646e-5 9.831681741344593e-5 4.519284261915646e-5 9.476667357386595e-6 -0.00021255847609426032 -0.00012131943007988071 0.00013453809874773752 5.502985364002046e-5 -0.00012808283285993335 -3.0074186768603733e-5 3.180628841802386e-5 -7.100802399507949e-5 -0.00020447752108156014 5.792263250624394e-5 8.738504034409836e-5 1.37989945433326e-5 -0.00010477290565903388 -3.93261859866845e-8 1.0104534296271948e-5 1.814339334670169e-5; 0.0002104706845599135 4.1538100940311876e-5 -4.929321239974018e-6 -5.3646445867708464e-5 0.00019299466382981218 2.732384107694908e-5 1.1521921787436652e-5 3.326012754732716e-5 -1.238531605884683e-5 -6.770724850902847e-5 -9.031798528703872e-5 1.966842590227397e-5 -4.6047941414501355e-5 -0.0001495309361602262 8.488426554286147e-5 -3.655282036597216e-5 -0.00015862064448847363 -1.502354190997676e-5 -0.0001506025246457262 7.517723586679497e-5 -0.000101743981777325 -0.00013362429583175752 7.155289650643059e-7 -0.00022287094454477208 0.00021854552776822344 -2.9082720193757564e-5 -3.0776357580536e-5 2.0954951632671854e-5 -3.919056964187008e-5 -0.00012062683801595693 0.00010717283019367147 3.786738032395536e-5; -0.0002099333436194367 -0.00011701318759830438 -5.559166708505431e-5 6.940555579834417e-5 5.541261601368701e-5 5.805254443887654e-5 -9.053009886654282e-5 -2.5358459583592957e-6 1.9021290092970644e-5 3.565978807503737e-5 -9.995530887632861e-5 5.375471994969649e-5 0.000222801179581919 9.44598720093691e-5 -9.398896542503907e-5 -5.063542300146782e-5 2.1056042562608433e-5 
-0.00010535246143942086 0.00010218260252489257 0.0001343919644864713 0.00010579221237324709 -0.00016309585077348623 2.4607999541818058e-5 0.00016149123951666495 -8.81268791704031e-5 0.00012261555060152088 -0.00010712322579590545 2.371835655621634e-6 0.00018764804520518321 -1.2795704030821072e-5 -5.2059262422943846e-5 5.4364154159457684e-5; -4.675916747720932e-6 0.00011309321477288849 -5.2434999210745056e-5 0.0002708578692048553 -8.197502696441964e-5 -2.119937865667866e-5 -0.0001499301136706426 -5.515877577147969e-5 0.0001320872840236411 -4.107443904525794e-5 9.174242647110874e-5 -9.132007774918158e-5 -7.938034044367924e-5 6.138753148240258e-5 -2.970954604701074e-5 -3.0848908376425415e-7 0.00016987079349325312 0.00023120065513045765 -3.2627565210200946e-5 0.000157629460468677 -4.753207334455021e-5 5.1996482093604066e-5 7.060508566189078e-5 -2.743012313170444e-5 -8.365419793916065e-6 0.00022792650330789662 -8.516712781296882e-5 2.459052544996579e-5 -4.7240769467526265e-5 -2.8844456514558984e-5 1.6246988402928324e-5 -0.00021973143096569323; -0.0001995693317999262 7.469272658008054e-6 -0.00011277169766025736 -9.814190785755928e-5 -9.090796889859791e-5 2.6823946553591388e-5 1.8698266229442107e-5 -9.154562654794659e-5 -7.120880114731902e-5 -6.19297360221896e-5 -9.009613937587915e-5 0.00020170726851438079 -6.906881832172729e-5 -0.00018203120349562886 -0.00014845938209100528 -8.240021348819079e-5 -7.735114273737422e-5 0.00017922298787970855 5.582224401679006e-6 0.0001223434220614722 -0.00011835219707927596 -3.406201800453236e-5 1.31359657993859e-5 4.320466827130292e-5 0.00010280631899013797 -2.7494285733598465e-5 1.3334904122998346e-5 -8.04161833138116e-6 -0.0001592600609560644 0.00010057871908293501 0.00010312229927740712 -5.23645256739594e-5; 0.00023989189659649496 0.0001455395483545234 -9.321849054952686e-5 9.257267239907942e-5 4.103270631920747e-5 5.650934846658672e-5 0.00014390071166148948 3.00242006388411e-5 -3.6652129737961394e-6 4.2058845535459075e-5 
-5.837099732426849e-5 -2.219803198651756e-5 0.00014261163025099588 5.955324897248035e-5 -4.4485581641240304e-5 0.0001368262253136578 7.752687971977347e-5 3.262573019772522e-5 -8.053173795171818e-5 -8.160695714006326e-5 -3.176684393376662e-5 -2.754926961908876e-5 -9.43178731930981e-5 7.558927988679338e-6 3.2978974301905324e-5 -6.949566730594447e-5 -0.0001126999980287482 0.0001518585574713736 3.86380940809307e-5 2.3502266883109507e-5 0.00012560119804863487 0.00011022459668044482; -8.854583266106641e-5 0.00014226531802528496 -4.4329554477325296e-5 -4.6211705075679885e-7 1.951598640561612e-6 -0.00024145176266933757 0.0002232727887513221 -9.01734862072322e-5 -5.694184856287893e-5 -7.310229248282594e-5 -3.889600757424935e-5 0.0001515997604638414 -1.3695470081027572e-5 7.088603469860559e-5 -2.4226769348342525e-5 4.4590720242969335e-5 0.00019993445621189468 0.00017078889086228615 -5.8516496757829025e-5 5.325574911516896e-5 -0.0001327129253713336 -3.329206138182562e-5 -0.00018647331726419974 5.892015851536649e-5 -9.517730050169947e-5 6.52441189213749e-5 -0.00011650088399796722 2.221365630071112e-5 -0.00013849745719376254 0.00010551377217602805 5.43563541196372e-5 -3.7521579180928615e-5; 4.4562229088566074e-5 7.194065714657012e-5 -0.00011054850161372094 -4.3036135100468694e-5 -2.8813891692611134e-5 0.00020735011370789127 -1.6608290895870364e-5 5.8052349270379626e-5 8.47430101589519e-5 7.818644102599052e-5 2.939891332288979e-5 9.21418950770177e-5 -1.2517023567396397e-5 -8.687204804801064e-5 5.925284953491221e-5 5.9822415134931806e-5 -9.498444974952079e-5 -3.118146101542435e-5 -6.830822072288529e-5 -0.0003533596272586722 -3.8125408150357053e-6 1.297331167961436e-5 -0.00012399625389435118 -4.218632507900248e-5 -6.739978103496531e-5 3.771992218617142e-5 -7.652890322560599e-5 -9.896246315975022e-5 8.550459919433322e-5 -1.8918254643268796e-5 1.546853404079178e-5 0.00011004755382828056; -0.00010544572459766777 9.6371811837887e-5 4.6981858491196964e-5 -2.956935013062447e-5 
6.776965865865226e-5 -3.496661183308141e-5 -5.373561388696849e-5 -6.010081078275237e-5 -3.074898658669922e-5 -9.177374790308147e-6 4.724793662316965e-5 -8.2764512894211e-5 3.283691017347621e-6 6.478758560260608e-5 -0.00011725375979440454 -5.5389551830014614e-5 -0.0001082968449685737 4.172979943686947e-5 0.00016250951712718065 6.668863325612477e-5 0.00011225345925194977 8.448911076411427e-5 3.958152929553637e-5 1.503364156232964e-5 -7.699310689931937e-5 3.0105811406972242e-5 -4.415070388349961e-5 -5.6181186932374205e-5 1.631721871089996e-5 -9.844363643789662e-6 0.0001360895254998448 -0.00010387262618169398; -2.0863983565955583e-5 -7.75581743550677e-5 2.951302319636262e-6 -0.00015863937963558596 0.00013464184072972953 -1.825546181177411e-5 -8.656123330407026e-5 -1.7061384398854966e-5 -0.00017317419636350806 9.960250233008806e-5 -2.2247101271656573e-5 -1.3160312421397464e-5 -8.63761329423781e-5 2.8857720706134893e-5 1.8393576708614554e-5 5.239346107850944e-6 -0.00010236461324211048 4.524203674691975e-5 -0.00013020832787559504 -5.185425337230677e-5 -0.000216674600353589 1.5955405131093053e-5 -5.7541774851868585e-5 -6.0951004827515794e-5 -8.651665351179027e-5 7.872697230371115e-5 5.5684538408543116e-5 7.162266180560027e-5 -2.4282435817893322e-5 -0.00026177081102331056 2.37759881815872e-5 -8.164086696749897e-5; 3.1118342211527334e-5 0.00013975587271368317 -0.00015616692771920642 -5.914950898051012e-5 5.595222601067534e-5 2.6982335019668362e-5 -0.00010975101113338667 1.1705880841589327e-5 -8.29076315324884e-5 -2.9359877328893155e-6 1.7256367844936506e-5 -8.447306108911558e-5 -4.314083789843168e-5 5.155103199166255e-5 0.00011536911106184891 3.171379566872664e-5 0.00015653271757281443 0.00010072672657650913 3.9067045780650664e-5 0.00022734110016125336 -6.454734912160001e-5 7.923633447597473e-5 -2.3165722735136277e-6 4.872700001095091e-5 4.666709637484007e-5 0.00015616013033529713 -8.560757205607525e-5 -2.842410585919709e-5 -1.9249072375998107e-5 6.92317915928872e-7 
0.00012272498051850523 -0.0001701121444978718; 0.0001120227139885766 -2.4428739738621372e-5 -0.0001484044072590483 -9.270138561490574e-5 -1.8410722436361332e-5 9.715537872168649e-5 -4.153696081131901e-5 6.404715223240112e-5 -8.180828753621981e-5 -7.276455011861049e-6 -0.00012527591333722307 -9.577443549546092e-8 6.560891602451999e-6 -0.00020238082233348337 -0.00011253882566464418 -4.669163349288033e-6 -1.1726130100700347e-5 -0.00013493158930444574 7.379943422591374e-6 0.00011043862887660509 5.1717056162913146e-5 -0.0002544234916008751 0.00014322856132408944 5.13059645577131e-5 -3.376584153968881e-5 -0.00011068012408443444 0.00011389697884212622 -5.8852492106350446e-5 5.5847002482063765e-5 0.00013980698404836324 4.372327803636879e-5 8.486713759836766e-5; 3.455099803541834e-5 4.163864188018644e-6 -0.00014264867778255688 -4.7968263887734444e-5 5.428682206112026e-5 -7.077367618766781e-5 1.6703783413690926e-5 0.00013282015433215593 5.803629764268758e-5 -7.098123743052647e-5 8.542290751502902e-5 0.00020800666979194635 -1.8122622267395703e-5 -0.00011891072782308771 -0.00015139228335901166 -6.7467207452804885e-6 -5.697860607032631e-5 9.468616019221041e-7 9.722240801348331e-5 -1.592296401933634e-6 9.940418393963013e-5 9.889609654347391e-5 -0.00015123459880559643 6.403563028580476e-5 6.734169445845236e-5 2.859297646064641e-5 -7.042673670075062e-5 6.0116417527030086e-5 3.734451828787169e-5 -2.4078011764497873e-5 -0.00015423153664304942 2.946662615723366e-5; -2.5306570858607246e-5 -0.00011136606273628999 -5.1605177685070526e-5 -3.8566548863104966e-5 -0.000118666212909534 9.138342722339826e-6 0.00015458415062477042 4.532745795752321e-5 -9.73617651818068e-5 5.8761069414522153e-5 -9.792206484978631e-5 0.00013895644140741095 5.921420514079371e-5 5.233242980632266e-5 0.00010086851314719918 8.192418234267966e-5 -2.2373150312138527e-5 -1.882562439111721e-5 5.107622481029579e-5 0.00025000140802314817 -0.00014789204662354498 4.316543261831716e-5 -6.011940505975428e-5 
0.0001361871100760442 0.00012632523347003318 -9.269634840000785e-5 0.00014898132856780213 -2.210276845124335e-5 -9.469263106273059e-5 -6.543952705958836e-5 8.449255172474222e-5 -0.00011279498077618091; -3.0146839745556765e-5 9.223248039708545e-7 0.00010468687210873501 7.721153920925796e-5 3.209971433449188e-5 -3.7904402089181146e-5 -1.2781109799976321e-5 -3.3298059324249062e-6 -1.604936627787724e-5 -2.262508391226675e-5 -2.5048268836088822e-5 3.605033376430487e-5 5.148423329869737e-5 9.276562301152748e-5 8.60568990529443e-5 0.00011534821638547512 -8.57162366091326e-5 -4.3888411565549664e-5 -6.169017328006072e-5 -3.2631028117011655e-5 7.599165943161813e-5 9.546286415484866e-5 -5.444418533528574e-5 -5.34608614916068e-5 3.3250960003947005e-5 0.00016553232959181527 -9.266345918235952e-5 8.906205144653708e-5 7.583023503599338e-5 1.172908613676385e-6 1.9112779372384646e-5 -8.405665616096158e-5; -0.00021178277918371848 -7.073819531361622e-5 -4.724745329004811e-5 -5.787032594090853e-5 -5.5055596272010066e-5 -5.388214437990312e-5 -9.266872464214103e-5 8.971184772670693e-5 8.368856990709634e-5 -0.00010601215424250073 4.0148140874211353e-5 -8.98499459028376e-5 -0.0001613261813647201 3.4492793679223954e-5 0.00014446011093676694 0.00018236897060413802 3.3917050791192196e-5 6.32957050809502e-5 -5.801280374290709e-5 1.3153490476529308e-5 0.00011740919192087501 9.573067993622923e-5 -3.850949385670543e-5 -4.4877342839038616e-5 2.9048041801423065e-5 -4.469645161881059e-5 -5.923718733830264e-5 -7.072983523831736e-5 2.288090738582826e-5 1.9841765356238477e-5 1.8308033508936296e-5 1.806382078746934e-6; 0.00011318048051677037 5.268764504349539e-5 5.11692691047378e-5 -1.16982752247488e-5 -9.886975170460132e-5 -9.296203064381654e-5 -7.687555774696031e-5 -6.839112905666796e-5 -3.4180424003855404e-5 4.506011796783977e-6 0.0002519793691730814 3.0410259255065817e-6 -0.00012485797345520392 8.743876093418575e-5 0.00011556588866866318 -7.764758778152321e-5 -4.755930570766992e-5 
0.00020977501696094184 6.5435348338234e-5 5.382915184369595e-5 -0.00010289797649833238 7.029621916941429e-5 4.562653190575006e-5 1.2886357494250278e-6 2.6712583681236903e-5 -7.340238665578345e-5 -2.6460198883293936e-6 -2.5994080383081883e-5 0.00015795494834079806 5.006607385936188e-5 -9.438784184980321e-5 -5.956048119927913e-5; -0.0001810978812480632 -0.00011951798860340216 -2.653371502478008e-5 0.0001598948636475661 0.0001574363612254792 -0.00023036647525972423 7.422474100363817e-5 4.8593081393456515e-6 5.171984030508983e-5 0.00015639735447817564 8.238278316782413e-6 1.3163770565513741e-5 1.2817056633285087e-5 -0.00011480433949858919 4.709247678161609e-5 -0.0001376399542985747 -4.460278632708873e-5 -9.356051504968706e-5 9.790992571848476e-5 5.674178992390578e-5 0.0002947704686468188 0.00025073619854279824 -0.0002467595570104486 -0.00010216933351133898 1.6507941656672447e-5 0.0001214212368742305 -0.00011390070194269282 -3.550455421449448e-5 -0.00018381962781663155 -6.989833691281205e-5 0.0001327107800554215 -0.00010260021572125022; -1.1497235572873234e-5 0.00011088370029059837 9.113346823453322e-6 -0.00012078080380856908 1.2340096691647475e-5 -2.3583613252957377e-6 2.3371749651579267e-6 8.90016142878688e-5 -9.654730292792459e-5 -9.181159311961591e-5 -5.323957720943669e-5 -0.0001850406863010492 0.00012699932067365897 0.0001309223715000784 4.924017939660607e-5 2.491374452833212e-5 3.506429293369877e-5 0.00011273902582647264 8.189237796645487e-5 -5.1270510266162333e-5 -0.00010597303680665875 0.00010442222792366387 2.6088991762977947e-5 4.002660880391324e-5 0.00014982407246893036 8.807685462701644e-5 6.836762422639402e-5 1.9088811132261775e-5 3.048434682600286e-5 -0.00012547863498153338 -0.00011197104522548861 -0.0001217799383081523; -0.00016848913792440106 -9.831051857128846e-5 -6.423770382918664e-5 8.545142345443785e-5 -3.5804767596196076e-5 -1.572706815005789e-5 -2.207933566915723e-5 -2.880492893549146e-5 -0.0002740239927350539 0.00011662751212331894 
-0.0001144167675209555 7.687119885458238e-6 8.37668791957032e-5 -0.00014787901300390442 6.975020523729778e-5 -7.44053889337191e-5 -3.1991631023476614e-5 3.939246534916368e-5 5.394436082654078e-6 -3.877708723105781e-5 -4.6323233893262014e-5 -1.893648759671781e-5 -0.00017579930709128795 -2.1460526388024113e-5 3.2348872717353905e-5 -3.820396732573219e-5 -8.273520781417597e-6 0.000151828289470523 2.545713295891355e-5 -0.0001043826835501256 0.000130256108330447 2.3329371923116694e-5], bias = [1.0901971606496968e-9, 1.1195360298030545e-10, 3.486280104242549e-9, -1.2806817776345284e-9, -5.63621335310489e-9, 3.427378808976673e-9, -1.4053348983102016e-9, -1.610607445710005e-10, -1.988680402844395e-9, -3.200815872479825e-10, 4.051404187916165e-10, 4.0100084403260194e-10, 5.056354115062556e-10, -1.479766434641262e-9, 1.1274392231314167e-9, 2.0233213667339527e-9, -3.3331721236845113e-9, 4.141581563429113e-9, -4.1891503376380644e-10, -9.412883666573707e-10, 2.180109719327416e-10, -4.586156590596905e-9, 2.232899249338953e-9, -1.6183499026519697e-9, 3.7115679188420536e-10, 1.454356447438859e-9, 1.921906970679617e-9, -1.0589021404248926e-9, 1.431877048950527e-9, -5.76799578114796e-10, 9.293220014547897e-10, -2.988874107294579e-9]), layer_4 = (weight = [-0.0006325564968720554 -0.0007851738227025042 -0.000743083677281069 -0.0006717564550433921 -0.000727520303472048 -0.0008503022426730849 -0.0007191731029263355 -0.0008760296800976515 -0.0006325969870025125 -0.0006247136278008363 -0.0008090688789077634 -0.000766574383549693 -0.000657192542435103 -0.0008166116864312149 -0.0006195659943695264 -0.0008027445352815282 -0.0006729486778210866 -0.000563818493057049 -0.0007023878892369008 -0.0005733681317090004 -0.00043562230997569635 -0.0007539152894848845 -0.0005818070351457066 -0.0003290514545432851 -0.0005852779659237664 -0.000622938941543866 -0.0008629370624689347 -0.0008124396989353264 -0.0007942956437667562 -0.0007264597261666555 -0.0007026318910344304 -0.0007411081650725422; 
0.00042004383822744166 0.00029354474565372856 0.00017296312311013023 0.00012679038003917404 0.0003145241660997308 0.0003931228915068941 4.7704560132733384e-5 8.292285657481812e-5 0.00021531051433560457 0.00019174667838282825 0.00032560116589643327 0.00011815991819667954 0.00025081685755024935 0.00029399421686405076 1.4275468980413641e-5 0.00012723100159547183 0.0002931230427831921 0.00011816140944993242 0.000290007872433805 0.0001744959180346551 0.00035927477154989806 0.000274595287270406 0.0002443454424613781 0.0002187679236793043 0.00018562029820348337 0.00037218054557826085 0.00016424882052414082 0.0002721037239144522 0.00035464106355030835 0.00028324753900728806 0.0002533021640389095 0.0002313583464226723], bias = [-0.0006955616883727599, 0.0002171220014284596]))Visualizing the Results
Let us now plot the loss over time
begin
    # Training curve: the recorded loss after each optimizer callback.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(
        ax, 1:length(losses), losses; marker=:circle, strokewidth=2, markersize=12
    )
    fig
end
# Finally let us visualize the results
# Re-simulate with the optimized parameters `res.u` and convert the trajectory into a
# waveform. This rebinds the global `prob_nn`, which `loss` also reads. That is harmless
# there because `loss` always overrides the parameters with `p=θ`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, dt=dt, saveat=tsteps, adaptive=false)
)
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params),
)
begin
    # Overlay the target waveform with the untrained and trained neural-ODE predictions.
    # Every series is drawn as a line plus semi-transparent circle markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Target data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, strokewidth=2, alpha=0.5
    )

    # Untrained model prediction (computed before training).
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, strokewidth=2, alpha=0.5
    )

    # Trained model prediction.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5,
    )

    # Each legend entry groups a line with its markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
# ## Appendix
using InteractiveUtils
# Print the Julia build and environment information for this run.
InteractiveUtils.versioninfo()

# GPU toolchain details are printed only when MLDataDevices is loaded and the
# corresponding backend package is both loaded and functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA)
        if MLDataDevices.functional(CUDADevice)
            println()
            CUDA.versioninfo()
        end
    end
    if @isdefined(AMDGPU)
        if MLDataDevices.functional(AMDGPUDevice)
            println()
            AMDGPU.versioninfo()
        end
    end
end
Julia Version 1.12.4
Commit 01a2eadb047 (2026-01-06 16:56 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.