Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie

Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.
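We need a very crude two-body path. Assume the one-body motion is the Newtonian two-body relative position vector $r = r_1 - r_2$; `one2two` splits it between the two bodies so that their centre of mass stays at the origin.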
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) `path` into the paths of two bodies with masses `m₁` and
`m₂`, returning `(r₁, r₂)` with `r₁ = m₂/M * path` and `r₂ = -m₁/M * path`, where
`M = m₁ + m₂`. By construction `r₁ - r₂ == path` and `m₁ r₁ + m₂ r₂ == 0`.
"""
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in orbital variables into Cartesian coordinates, returned as a
`2 × N` matrix whose first row is `x` and second row is `y`.

`soln` must have 2 rows `(χ, ϕ)` or 4 rows `(χ, ϕ, p, e)`. With 2 rows, `p` and `e` are
taken from `model_params = (p, M, e)`, which is then required. With 4 rows they are read
per sample from rows 3 and 4, and `model_params` is ignored. The radius follows
`r = p / (1 + e cos χ)`.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` first so a missing `model_params` trips this assertion
        # rather than surfacing as a MethodError from `length(nothing)`.
        has_params = !isnothing(model_params) && length(model_params) == 3
        @assert has_params "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled series `v` (spacing
`dt`). It uses second-order central differences in the interior and three-point
one-sided stencils at both endpoints, so the result has the same length as `v`.

Throws `ArgumentError` if `v` has fewer than 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    length(v) >= 3 ||
        throw(ArgumentError("d_dt needs at least 3 samples, got $(length(v))"))
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled series `v` (spacing
`dt`). It uses second-order central differences in the interior and four-point
one-sided stencils at both endpoints, so the result has the same length as `v`.

Throws `ArgumentError` if `v` has fewer than 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    length(v) >= 4 ||
        throw(ArgumentError("d2_dt2 needs at least 4 samples, got $(length(v))"))
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free mass quadrupole tensor along the planar `orbit`
(row 1 = `x`, row 2 = `y`), scaled by `mass`.

`component` is an index pair. `(1, 1)` gives `x² - (x² + y²)/3` and `(2, 2)` gives
`y² - (x² + y²)/3`. Every other pair falls back to the off-diagonal entry `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    x, y = orbit[1, :], orbit[2, :]
    xx = x .^ 2
    yy = y .^ 2
    xy = x .* y
    third_trace = (xx .+ yy) ./ 3
    is_xx = component[1] == 1 && component[2] == 1
    is_yy = component[1] == 2 && component[2] == 2
    selected = is_xx ? xx .- third_trace : (is_yy ? yy .- third_trace : xy)
    return mass .* selected
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (finite differences, sample spacing `dt`) of
the selected trace-free quadrupole component of `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    quadrupole = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(quadrupole, dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(xx, xy, yy)` components (in that order) of the second-derivative
quadrupole series used for the (2,2) strain of a single body of `mass` on `orbit`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_series(c) = h_22_quadrupole_components(dt, orbit, c, mass)
    xx = component_series((1, 1))
    yy = component_series((2, 2))
    xy = component_series((1, 2))
    return xx, xy, yy
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return `(h₊, hₓ)` for the (2,2) mode of a single unit-mass body on `orbit`:
`h₊ = √(π/5) (h_xx - h_yy)` and `hₓ = -√(π/5) · 2 h_xy`. The prefactor is computed
in the type of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    xx, xy, yy = h_22_quadrupole(dt, orbit)
    amplitude = √(T(π) / 5)
    h_plus = amplitude * (xx - yy)
    h_cross = -amplitude * (T(2) * xy)
    return h_plus, h_cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the `(xx, xy, yy)` quadrupole series of a two-body system as the component-wise
sum of each body's contribution.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body_a = h_22_quadrupole(dt, orbit1, mass1)
    body_b = h_22_quadrupole(dt, orbit2, mass2)
    return (
        body_a[1] + body_b[1],
        body_a[2] + body_b[2],
        body_a[3] + body_b[3],
    )
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return `(h₊, hₓ)` for the (2,2) mode of two bodies (BH1 of `mass1` on `orbit1`, BH2 of
`mass2` on `orbit2`). The masses must sum to 1 to within `1e-12` (checked with
`@assert`).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    xx, xy, yy = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    amplitude = √(T(π) / 5)
    return amplitude * (xx - yy), -amplitude * (T(2) * xy)
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the `(h₊, hₓ)` (2,2)-mode strain for the orbit encoded by the ODE solution
`soln`, whose samples are spaced `dt` apart. `model_params` is forwarded to
`soln2orbit`.

`mass_ratio = m₁/m₂` must lie in `[0, 1]` (checked with `@assert`). A zero mass ratio
treats the orbit as a test particle and uses the one-body strain. Otherwise the orbit
is split into two bodies with `m₂ = 1/(1 + mass_ratio)` and `m₁ = mass_ratio * m₂`,
so the total mass is 1.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        return h_22_strain_one_body(dt, orbit)
    end
end
# Simulating the True Model
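RelativisticOrbitModel defines a system of ODEs which describes the motion of a point-like particle in a Schwarzschild background, using $u[1] = \chi$ and $u[2] = \phi$,
where $p$, $M$, and $e$ are constants:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p-2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}}$$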
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for a point-like particle in a
Schwarzschild background. `u = (χ, ϕ)`; `p`, `M` and `e` are constant parameters; `t`
is unused. Returns `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    cosχ = cos(χ)
    # Factor shared by both rates.
    shared = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root_term = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = shared * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * root_term)
    ϕ̇ = shared / (M * (p^(3 / 2)) * root_term)
    return [χ̇, ϕ̇]
end
# Simulation set-up: physical parameters, initial state and sampling grid.
mass_ratio = 0.0 # test particle (compute_waveform then uses the one-body strain)
u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
datasize = 250 # number of saved samples
# NOTE(review): tspan is Float32 while u0 and the model parameters are Float64, so
# tsteps and dt_data are Float32 as well — confirm this mixed precision is intended.
tspan = (0.0f0, 6.0f4) # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
# Spacing of the saved samples; passed to compute_waveform for the finite-difference
# derivatives of the quadrupole moment.
dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step size (the solves below use `adaptive=false`).
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl (`RK4`).
# Integrate the relativistic model with fixed-step RK4 (step `dt`), saving at `tsteps`,
# and keep the plus polarisation (first element) of the resulting strain.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the orbital variable χ, i.e. `first(u)`; its first layer applies `cos` to it) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and the time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
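We will deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal distribution with a very small standard deviation (1e-4), and biases are initialized to zero.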
# All Dense layers share the same initialisers: truncated-normal weights with a tiny
# standard deviation and zero biases. Presumably this keeps the untrained network's
# output near zero, so the `1 .+ nn_model(...)` factor in ODE_model starts close to 1.
const nn = let shared_init = (;
        init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32
    )
    Chain(
        Base.Fix1(fast_activation, cos),  # the network sees cos.(input)
        Dense(1 => 32, cos; shared_init...),
        Dense(32 => 32, cos; shared_init...),
        Dense(32 => 2; shared_init...),
    )
end
# Initialise the parameters and (empty) layer states of `nn`. The rendered REPL display
# of `(ps, st)`, i.e. the 1_186 freshly drawn Float32 parameters (all biases zero)
# together with empty `NamedTuple` states, is output rather than source and is omitted.
# NOTE(review): `Random.default_rng()` is unseeded here; pass a seeded RNG
# (e.g. `Random.Xoshiro(0)`) if reproducible initial parameters are needed.
ps, st = Lux.setup(Random.default_rng(), nn)
# Similar to most deep learning frameworks, Lux defaults to using Float32. However, in
# this case we need Float64.
# Convert the parameters to Float64 and flatten them into a ComponentArray, the single
# parameter vector handed to the ODE problem and the optimiser.
const params = ComponentArray(f64(ps))
# Bundle the network with its (empty) state so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
# Displayed as:
#
#     StatefulLuxLayer{Val{true}()}(
#         Chain(
#             layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
#             layer_2 = Dense(1 => 32, cos),  # 64 parameters
#             layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
#             layer_4 = Dense(32 => 2),  # 66 parameters
#         ),
#     )  # Total: 1_186 parameters,
#        #        plus 0 states.
#
# Now we define a system of ODEs which describes the motion of a point-like particle
# with Newtonian physics, and uses
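$u[1] = \chi$ and $u[2] = \phi$, where $p$, $M$, and $e$ are constants and $\mathcal{N}_1$, $\mathcal{N}_2$ are the two outputs of the neural network:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\left(1 + \mathcal{N}_2(\chi)\right)$$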
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE. The Newtonian rate `(1 + e cos χ)^2 / (M p^(3/2))`
is multiplied by `1 .+ nn_model([χ], nn_params)`, whose first and second outputs give
`χ̇` and `ϕ̇` respectively. `u = (χ, ϕ)`; `p`, `M`, `e` come from the global
`ode_model_params`; `t` is unused. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ = first(u)
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, however, `st` should be used to store the state of the network.
    y = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * y[1], newtonian_rate * y[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
# Integrate the neural ODE with the untrained parameters, using the same fixed-step RK4
# set-up as the true model, and plot its waveform against the data.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equation: the mean-squared error between the predicted and the true waveform.
# Mean-squared-error loss from Lux.
const mseloss = MSELoss()
"""
    loss(θ)

Solve the neural ODE with network parameters `θ` (fixed-step RK4, saved at `tsteps`)
and return the mean-squared error between the predicted plus-polarisation waveform and
the true `waveform`.
"""
function loss(θ)
    # NOTE(review): u0, tsteps, dt, dt_data, mass_ratio, waveform and prob_nn are
    # non-const globals, so this function is not type-stable; consider declaring them
    # `const` where they are defined.
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
# Warmup the loss function
# The first call compiles the loss (and its solver path) before training starts.
loss(params)  # returned 0.0007265651980571294 in the rendered run
# Now let us define a callback function to store the loss over time
# Loss value recorded at every optimiser iteration by `callback`.
const losses = Float64[]
"""
    callback(state, l)

Optimization.jl callback. Append the current loss `l` to `losses`, print the iteration
number `state.iter` and the loss, and return `false`, i.e. do not halt the optimisation.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This appears to work well here because the Newtonian model already provides a very good initial guess.
# Gradients of `loss` are taken with Zygote; differentiating through `solve` relies on
# SciMLSensitivity.
adtype = Optimization.AutoZygote()
# Optimization.jl objectives receive `(u, p)`; the hyperparameter slot `p` is unused here.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small first step and backtracking line search, starting from the network's
# current parameters.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-4.928256385023518e-5; -0.00010969365393966287; 3.0927478746870596e-6; 2.9228309358580215e-5; -0.00012125086504954471; -4.596841608875763e-5; 8.747186075192957e-5; 6.240397487991794e-5; -3.8656959077331635e-5; -0.0001534546172478; 0.00014520031982098893; -7.538963836852299e-5; -2.1926080080452987e-5; -9.046314517029632e-5; -3.184854722345434e-5; 3.066580757148993e-5; 7.141166861392068e-5; -6.147872772996248e-5; -6.941700848976608e-5; -4.3830823415126224e-6; -8.787246042629558e-5; 0.00020821765065169768; -0.00021199611364827611; -5.8521320170182634e-5; -4.0555260056845985e-5; -4.5009503082782066e-5; 9.505671914662285e-5; 3.8939728256073014e-5; -5.477132162924241e-6; -5.548241460918648e-5; -6.233206659084613e-5; -7.84968869991938e-5;;], bias = [-5.4759473323549366e-17, -1.0662924020450199e-16, 2.09772773660647e-18, 7.063324249170184e-18, -1.28652333596944e-16, 3.449330228535122e-17, 5.13510454724978e-17, 9.116928981509348e-17, -1.286650133108435e-17, -1.3204183647981045e-16, 3.28684740973699e-17, 1.851129112454875e-17, 5.685751927472079e-18, -7.141683840526145e-17, -3.949464971824175e-17, 5.646088066395563e-18, 1.1070914770873802e-16, -3.2088635909915975e-17, -7.872631925154267e-17, -7.462265326130271e-18, -1.6586356933419672e-18, 2.6465230839923004e-16, -4.0781819454218716e-16, -8.4119779478292e-17, -1.8455008436023148e-17, -6.2696973092581e-17, 5.1696130474722516e-17, 9.804038362028046e-18, 1.0296042775151762e-18, -1.1320582130855967e-16, -5.814667550239383e-17, -1.7019971726447573e-16]), layer_3 = (weight = [2.580750996423255e-5 -3.7351743876921346e-5 -3.785311920217296e-5 -0.00012887716794501153 -6.388597931061064e-6 -0.00018472241273264552 8.675260259189248e-5 1.3134857262792097e-5 9.395928581363398e-5 -0.0001267265549453365 -8.994092123465675e-5 3.0273834717824754e-5 -1.2150797696270977e-6 2.3358435364221673e-5 0.0001545726475247824 7.52509288664618e-5 1.5842664006065537e-5 
-2.7392738693726635e-5 -0.00019730555569402215 0.000176696564530739 -2.5744325021500043e-5 4.913425736012746e-5 6.878421781002349e-5 -0.00021113249450571114 7.673276494608131e-5 7.51489344926257e-5 -3.718285071080078e-5 -3.450051800462541e-5 6.177138792701938e-5 0.0002226149875146969 0.00011801870940836848 0.0001173657867999435; 5.437012876095536e-6 -0.00012781980887881334 -0.00010468222000996878 -0.0001025975562900983 -7.523040232841613e-5 5.307451090414793e-5 -1.8778675550823224e-5 4.363246796641348e-5 0.00020976866096828828 0.00012230385061924923 -0.00010076143928232123 2.6513729496979237e-5 8.656780289823844e-5 0.00015361574503698397 0.0001002659499841996 -2.5183149884802082e-5 0.00011316558982583588 9.642350221623513e-5 -5.6572336800276544e-5 -1.306340538791673e-5 1.116813553165867e-5 3.2281130077218546e-5 -8.202476430810564e-5 1.3686810123213288e-5 6.20044610310602e-5 -0.00021996904081094296 -0.00033504998111526447 -3.025760099477838e-5 -8.579945754141672e-6 -7.149516579589986e-5 -1.4678867157603725e-5 -2.357371533650041e-5; 8.337654042120622e-5 -0.00010340086403006686 0.00017574664859862493 -3.946851240757392e-5 -3.9706516257085633e-5 -0.0002615695305604715 -0.00014051143473183924 9.589091830717057e-5 -7.917570139684227e-5 -3.3040058336281005e-5 -0.00017897476602514822 -0.00019523951443547655 0.00011152676204515136 -7.084113560553804e-5 4.9171965362667125e-5 6.362713218263815e-5 -4.430256042100781e-5 8.074397526511576e-5 3.290961801835963e-5 2.2114148738956612e-5 0.0001390876383783091 -4.105512227682008e-5 -1.2058342984463193e-5 -4.2858301024487145e-5 -9.632501707815071e-5 4.870302261847581e-5 0.0001269266464777092 4.6874308897578386e-5 -4.6421046429849625e-5 4.580693873187719e-6 -8.516072223132819e-5 0.00010313278078397491; -0.00011465655382145652 5.300032523572204e-5 5.994879166639486e-7 3.07682765856802e-5 -5.851639991980229e-5 -0.00020301363668060387 -6.030398265230808e-5 0.00010096508947079578 -3.62170169629864e-6 -0.00014386692612553006 
1.0545967348556096e-6 2.3599912194717584e-5 0.00018887457724539063 -1.3442410024631307e-5 0.00013057138300844916 -3.663867221568266e-5 -0.00010016315184422233 2.1735951327938325e-6 -4.421790993311844e-5 -0.00010476526779477062 0.0001264506735174883 -6.992047135469183e-6 0.00020622213366786504 -0.00010166003460418867 9.825346375926856e-5 3.3564143334464204e-5 -4.505262156648845e-6 0.00010147008458496527 -9.18261631590365e-5 -3.62895353895656e-5 -0.00014555860082274722 3.1503178639329393e-6; 9.462514160227578e-6 -0.00014655443797239963 -7.463735951463234e-5 0.000167608562466889 0.00020046873700918817 -0.00014447162234576798 0.00010713770038561767 -0.00011046955901550429 -9.515091792567107e-5 -9.945445767955269e-5 1.9511112757003953e-5 0.00016536120833476786 -7.345767940281412e-5 0.00016474883463812886 -0.00017101804101260807 -2.205432670310571e-5 8.183354961799215e-5 -7.235018406648492e-5 -9.009409100402403e-5 -0.00034598542393072054 2.8907713701453655e-5 3.5086460545413054e-5 0.0001526135996274991 -2.3065026337312004e-5 0.00013354121924491198 7.909197606492405e-5 -0.00013604027176627954 -6.744610132303493e-5 -6.826869926302234e-5 -2.916869985059239e-5 2.9286603737242322e-5 4.918706631284942e-5; 1.3051911378914305e-5 0.00014306192163149998 5.8312964325359135e-5 0.00011523150744853317 -3.0167888546096275e-6 -6.991339039285019e-5 4.3646849304255194e-5 3.913285975233283e-5 0.00013654986680723474 -0.00020184589964100544 -8.113567462186593e-5 0.00011027192370039889 -0.00011330831482535431 2.2596154236284855e-5 -3.0173195840853704e-5 -0.00013968743739877122 -8.810380829725841e-5 -0.00018888511992752916 0.00010648828022187164 7.172740727202786e-5 -3.2264522518912275e-5 -7.444078953467711e-5 7.967827543852369e-5 -6.896323763992984e-5 0.00017539074754114933 5.720564724999201e-5 0.0001496085227025238 -2.5372102458442308e-5 8.936432125285085e-5 -2.2076186427677552e-5 8.171005022251612e-5 7.058157121551802e-6; 8.826459495778202e-5 -0.00015460053284457808 7.317938114012579e-5 
6.413011986025604e-5 1.7510762042567154e-5 4.2677668240703655e-5 -7.931712345723951e-5 -7.198946288824997e-5 1.528682739901353e-5 -0.00018800266436572025 3.952547082937351e-6 1.9595653136112073e-5 -0.00010168641485777694 0.00021845348644730682 0.00010871220408900637 3.945781688982614e-5 8.785624665525049e-6 -0.0001386868202853108 -2.7626165528814474e-6 -6.311498530503735e-6 -5.7630866215712896e-5 7.505911023783805e-5 -3.464089197493769e-5 -5.9669520417610076e-5 0.00010946575319123454 -7.753660204160025e-5 0.00016327036835447938 3.059771231803109e-5 -5.75491778634386e-6 -8.192563429745835e-5 -8.54663990348669e-5 1.0602946965282635e-5; -3.167156370801739e-5 -4.090197086663048e-5 9.042421885350888e-8 -0.0001263884301504851 7.4503960998467e-5 -7.327903273727627e-5 9.8196581742e-5 -0.00011659537682754834 9.418873690555001e-5 9.062491465845805e-5 1.8130628479255322e-6 4.089448651319423e-5 -0.0001498963798138112 -6.856739907272153e-5 -5.268586807397989e-5 0.00021274431804448507 -0.0001785779864501947 -1.932271385777654e-5 -0.00014767514636952362 1.8746737243300163e-6 1.9231739709708064e-5 8.651774572461212e-5 8.066695903214896e-5 -1.2419415331181596e-5 2.294791141429111e-5 -8.9972626089378e-5 0.00013932219320230178 1.4686981952024048e-5 4.2773488015142896e-5 0.0001036557036360922 6.948320827467452e-5 5.4698520610314316e-5; -0.00015864330102744926 4.048356809247511e-5 -5.979743050233563e-5 2.3867003484659603e-5 0.00018041475218663806 9.782836227855275e-5 0.0001514937665447578 1.8134520225020678e-5 4.206272010519809e-5 -0.00014959602608403502 -8.495796114264054e-5 -0.000130669019334576 -0.00011673930762157937 7.003465530173639e-5 7.317888034918116e-5 -1.5244803897482211e-5 6.382300293948682e-5 7.684586112331949e-5 0.00023777665578704544 -0.0001546455406682483 -2.9879145721184698e-5 0.00014370741505580908 0.0001201965913892233 0.00011745563629260681 1.0325127064356237e-6 -2.9008029876829913e-5 4.405676541095278e-5 0.00015486118159201636 -9.433918798863061e-5 
0.000262882624021177 -6.905536928071078e-6 8.155132477581546e-5; 1.1857648061109527e-5 0.00016816029248844293 -6.291433309099542e-6 2.702801332022045e-5 -6.146570656596528e-5 -1.5026570580388785e-5 8.663792406992834e-5 5.337405029537335e-5 -1.0589647062018798e-5 -7.310914190884826e-7 -1.3476588791265976e-5 7.979961920920964e-6 3.3300545438821573e-5 -6.391738413980431e-5 -9.263115050883531e-5 -6.569115346677162e-5 7.640915002526786e-5 6.56597739373539e-5 0.0001466393631938109 1.8063937133108818e-5 2.512410075838518e-5 1.531930403982708e-5 2.1233819026041762e-5 0.0001100027920267753 -6.6753487578660234e-6 0.00013497243791904648 8.580298610582673e-5 9.774111542471141e-5 -0.00013744360610756502 -5.0159723358533634e-5 1.392497226448909e-5 7.539667805087304e-6; -1.3020897620002985e-5 3.394933324919391e-5 0.00021910535966774167 -6.537574275832865e-5 3.22161310238994e-5 -8.633078444952326e-5 -9.233260547013864e-5 -5.657393679691571e-5 -2.6832776108926403e-5 0.00016322354591715094 -0.00017721378147804353 3.55716244145948e-5 3.986359539587721e-5 -3.095230689475194e-5 -8.939554148789731e-5 5.5558738188415247e-5 8.346465134557303e-5 4.128654715052707e-5 0.00011721621073113706 1.5637947092477194e-5 -3.0694930808086845e-5 7.931771930338804e-5 -8.964679485620863e-5 1.3608570646034797e-5 -9.76782297014093e-5 -6.681546559527702e-5 3.507444367890407e-5 -6.29467734481974e-5 -4.157187989883407e-5 -3.984654298371768e-5 -0.00010037741352541467 -0.0001720597113828918; -3.8645302985105206e-5 -8.508282188910542e-5 0.00014788513363687736 4.1300398160231355e-5 1.8218435806311158e-5 2.8677080998887315e-5 6.889474026042221e-5 4.063661618506359e-5 -5.811145455577222e-5 -5.749330011084117e-5 8.753126167804937e-5 -2.536397224553346e-6 4.7320409074885464e-5 -5.457584483434343e-5 1.3397191277841838e-6 0.00011119934061657908 9.836824029566361e-5 -8.724797952804873e-5 -1.4439369725130124e-6 0.00021702759841805738 0.00016755799880133183 0.0001906037650850356 0.00011347333202946099 
0.0001291764201677766 -5.518531906186438e-5 1.872215944690737e-5 5.7603610920468074e-5 -9.099596353830584e-5 9.890747606637367e-5 4.556537170078553e-5 3.465009286532935e-5 -5.546336250613966e-5; -5.425282310177938e-5 -1.228587083467115e-5 4.320934946001007e-5 -0.00019934115138319774 4.571912919339108e-5 3.559622760796535e-5 1.3624559105594347e-5 -0.00015958466414311254 6.426761516187854e-5 -8.711774884438268e-5 7.991424914501345e-5 6.357016136390633e-6 5.733099365890483e-5 -0.00010635081875086807 8.978142992411295e-5 -3.7634321270235654e-5 1.648160483448083e-5 7.20279187225242e-5 -5.2639368586896466e-5 -9.092322747172356e-5 5.5234481042064004e-5 -0.00011063210135039831 -9.681937248392831e-5 4.5982402443696425e-5 -2.5303230213540974e-5 -0.00011898740362092309 0.00014972128534275107 -0.00012232391767618661 9.28562859916544e-5 -9.494136778819704e-5 7.660258976095072e-5 1.0265507771620338e-5; -1.556504536635206e-5 1.1102530727684885e-5 9.779303366743648e-5 4.147513903520649e-5 -6.242366332272972e-6 -6.777162922586729e-5 4.244212625810351e-6 2.9124939494826486e-5 0.0001901063020639735 0.000129119647347419 -9.346629519056927e-5 0.0001089536323277873 -8.414815817301496e-5 -6.878122655249907e-5 0.00011476142538671021 2.5774932176169152e-5 2.716672282334325e-5 -0.00013485445230541313 -6.620838245650368e-5 -2.0022751748150545e-5 5.345607645070427e-5 -8.468366865343956e-5 1.8501963161554515e-5 0.00024532104146859137 -6.328035521361436e-5 3.450370935100652e-5 9.064670451410298e-5 -8.746475366173905e-6 -0.00016886933587299046 -1.2988972564167291e-5 -9.370255281025113e-5 0.00013798969451530224; -0.00018991388256559764 -0.00015737312256257065 3.386921693439998e-5 -0.00010524916629670745 0.00011046701351282085 8.557014174861234e-5 -8.938487931751507e-5 -8.320743123200454e-5 -4.1592067788733385e-5 0.00010970035313496888 -5.558658645619961e-5 -5.0030028575712604e-5 0.000128451449057175 0.0002297982044574723 -7.5547288503876e-5 3.0879729677544874e-5 -0.0001350223120943714 
-1.4023250123144339e-6 8.267444886136106e-5 7.285988780771783e-5 -4.9362554052060964e-5 -0.00012236753005433985 8.822663025332404e-5 7.237311169141258e-5 -0.0001345252768778301 6.767110411878077e-5 -4.13089129842436e-5 9.496937658972458e-5 0.00011224394060540414 0.00013010019560446325 5.9151361568218925e-5 9.011461756293853e-5; 0.0002612213681253153 6.519342030271448e-5 0.00017564974739556723 -0.000121354871519232 0.00018865991630929592 -1.791108690909851e-5 -4.114718133011583e-5 2.513494080388829e-5 -0.00014201751430178704 0.00021135755712503178 -4.603548252152282e-5 -0.00014779642908494358 -0.00012188379725799448 0.00010874236220600249 4.584341120547979e-5 0.0001165289319736455 8.998118180655594e-5 0.0001043798362637397 4.108428008961255e-5 -0.00012106603055386282 -6.304429218532279e-5 2.870117866884202e-5 5.701533853176947e-5 3.842868288969605e-5 0.00014743231775113649 0.00010048497977308201 -0.00011572827163279024 -6.563379822424328e-5 -7.495557677117813e-6 0.00016679055595644786 -7.305955325379126e-5 0.0002907725535599935; -0.00011219793912811133 -6.767330586067743e-5 1.9881042422763414e-5 -2.3597817456778264e-5 -0.00011968919960404946 0.00011282493261763881 -4.1010137177963e-5 3.078440323964844e-5 3.144226622329115e-5 -9.677347651072498e-5 9.033923317124283e-5 9.42101444853867e-5 -0.00020543407543651792 3.8862873512826967e-5 9.146377607624973e-5 -6.794662720846021e-5 -7.771676854086825e-5 -6.507989990847042e-5 -4.057195718256555e-5 4.5961028550788174e-5 -4.256117125247247e-5 -0.00011893089930148377 -2.7056360235143994e-5 3.236213644062296e-5 2.4556520034945868e-5 8.502848988076723e-5 7.194352412274537e-5 0.000171998207693957 3.273229350828508e-5 -0.0001509504242529709 2.189479102174902e-5 -0.00015362119546249235; -3.5768077600985364e-5 2.482647939771536e-5 9.634952473290514e-5 -0.00016278781661228906 -0.0002546210537052217 0.0002673961924069973 -0.00013265676645958033 -7.749621750441867e-6 5.376146034406848e-5 -0.00012247092056483067 -0.00015689296672050112 
3.1036210907154485e-5 1.1598592004266594e-5 -9.339676873921551e-5 0.0001177434085316429 7.667873321635962e-5 -0.0001095151976419181 -0.0001588856186803939 -5.090661691199416e-5 4.241211606724439e-5 0.00011685416281991473 -3.2657307406632426e-5 -5.824830916521611e-5 -0.00027397204505675925 8.054710522130491e-5 -5.6436843101872e-5 4.8813692751106955e-5 -1.3434124617086475e-5 3.692022326005875e-5 0.00013089720484870009 -3.95202480419961e-6 7.774928851588358e-5; -9.3963520407743e-5 -2.443446944660203e-5 -8.512887077366219e-5 5.127441975211543e-5 -4.008867042564368e-5 0.00010496882962335708 -1.5078318279009573e-5 -6.881360538962346e-5 -1.9796053743522285e-5 -6.239284200552979e-5 -8.097880049783292e-6 0.00016096713873132158 1.472442338574015e-5 -0.00015841551728506433 -0.00010484092424583795 5.3022774159142826e-5 -7.684427625925278e-5 -0.00011985508106143268 -3.146938549896711e-5 -9.600915589097173e-5 -8.430048117141123e-5 -2.133633031468108e-5 3.605549929938147e-5 1.3535792027068559e-5 0.00016960899731649867 0.00010840722163713407 3.2378088439984124e-5 0.00018567858999169276 6.847668621496082e-5 -3.317229791271584e-5 -4.641926286702094e-5 -0.0001426105657402142; -6.283676257990258e-5 0.00010538404684811646 7.887323671784352e-5 -0.00010322587250666612 6.67154389626322e-5 -3.462860614345359e-5 9.756325097586878e-5 -5.101799534568555e-5 0.00010803143312987007 3.324908983958429e-5 -9.340942698000395e-5 -4.0783193170139903e-5 -0.00014906229544831354 0.00012349028919888565 4.937044003927225e-5 -4.666097002737147e-5 -4.597320285793414e-5 0.0001192218559404836 0.00012237907584821786 -0.000114585665058623 0.00016073757514422675 -7.152456883742465e-5 7.535007979785065e-5 -8.731181797739101e-5 -5.46439033390948e-5 5.119453718723317e-6 -9.514587784855852e-8 0.00020734850922199028 -0.00015964226541509757 -2.1269729685084625e-5 1.4697046837315512e-5 -5.2573354976997225e-5; -6.503235440399729e-5 -3.156559055137586e-5 -3.3274046462269877e-6 -7.377263486782183e-5 -5.33481778337378e-5 
-0.00010747372093541753 -1.7245588887798764e-7 -1.4069916550257574e-5 0.0001744069178595565 3.2710215374331674e-5 -6.154976094773035e-5 7.095855303409924e-5 3.156971629428616e-5 8.505634664393229e-5 0.0002652614804457304 4.607528367265026e-5 3.5618812172426856e-5 9.296569819630706e-5 -0.0001965250392903245 0.00013233369295554214 -2.67258909378596e-5 6.179385868664754e-5 -6.108901820779062e-5 -6.279805333983472e-5 3.904393647749744e-5 3.7304789794830495e-5 8.072327377808406e-5 2.091727778966387e-5 -3.47711191796793e-5 -0.00014179402009981042 0.00022140497653813675 0.00015520927459069899; -0.00012261647562842286 6.144304578777871e-5 1.9916598184533307e-5 -4.808916920443455e-5 8.013576670203808e-5 -0.00010046345479200991 4.7296303173094366e-5 0.0002545386080688663 3.3616957161608874e-5 -0.00014271888422660135 -2.448114965806306e-5 -2.433224172851958e-5 9.679084841048889e-5 -0.0001409262046857874 -2.8430461234496154e-5 -6.671747278563387e-6 -1.2368260153049308e-6 -5.742001004397346e-5 -0.00010576898122086604 2.3234864138100186e-5 0.00014536733150601882 -8.072698320608323e-5 -3.4148885525034084e-5 -0.00013207295042808577 0.0001559404038649851 -5.558967742250698e-5 -2.9654724776595607e-5 7.42961195006834e-5 3.521430270633276e-5 -1.5086213998405115e-5 1.6838866164439052e-6 0.00017133496229656547; 6.766585057608279e-5 2.19197295443853e-5 -4.5093597226008295e-5 3.339922615101053e-5 -6.909675220253408e-5 -9.008730056688695e-5 -0.00022932955697637533 6.297629216692843e-5 5.070947510080837e-5 0.0001062185961125189 -5.9602009503799226e-5 -7.244944798753158e-5 -2.2596725546685484e-5 1.363592265238968e-5 -0.00010539940784678371 3.0058266485645012e-5 -9.32034753372455e-6 -3.088113452417343e-5 9.251563380309093e-6 -0.00010589335078133783 -4.944306282436029e-5 9.347067453905575e-5 2.59388663085518e-6 9.473364981363959e-5 7.894661167643459e-5 1.9095171875739087e-5 5.639328220400479e-5 1.3619849152525022e-5 0.00010556971166057 -5.960955467184575e-5 2.7841773327742385e-5 
3.4454698390397e-5; -8.281226464563182e-5 -0.00011604359306178433 7.29035310538758e-5 4.801756595012639e-5 7.374471178960541e-5 4.984796405521097e-5 -4.434493157563066e-5 -3.593646024486219e-5 0.00012620695642878195 -0.0002179751608312849 -0.00013666792065154438 0.00011570495562379032 9.181513723583532e-6 7.660069977615681e-5 -6.178271542317033e-6 -1.1080094646181457e-5 2.6454263297331377e-5 -5.715741645268203e-6 0.00019925585374918573 7.492455207876816e-7 0.00015903767789465114 -7.300284554248901e-7 -0.0001069156859560712 0.0001231812203510928 1.8329428482181216e-6 -2.5697487126613327e-5 2.658035382379477e-5 -2.0122199112243184e-5 -0.00014835734318724902 -1.907594187317565e-5 0.00010018907787495124 -3.779219687233907e-5; 3.0381760479818335e-5 4.1700163746888394e-5 -3.9766041554807763e-5 -0.0001971890315691186 -7.877953710715937e-5 -3.1170825495914356e-5 -6.0215400553837425e-5 -0.00010063617181015997 2.4646737910975814e-5 0.00013540362026093924 4.3819821184905166e-5 -5.851512931270365e-5 9.889655386347415e-5 4.016353243047565e-5 3.5474952637619396e-5 0.00014760992682555786 -1.7244488008267585e-5 2.4126839816622856e-5 -6.44637194274165e-5 -3.916265002784233e-5 -1.9444070021685206e-5 2.050920822500446e-5 -0.0001321325408572449 0.0001492163127476176 3.874831865674923e-5 4.61771441404067e-5 -1.6366905656589943e-5 -6.37374624421996e-5 3.6863179884553816e-5 -0.00015622862072644633 0.0001286223404530263 -6.921222764516085e-5; 0.00011426792431698471 -1.1881223628421028e-5 -5.743841077823641e-5 -5.154618023120384e-5 -1.9832709341843147e-5 3.3241918088927265e-5 3.8808998109270694e-6 -1.261872924410335e-5 4.8771034442646344e-5 4.515085626583999e-5 0.00024443623939836116 0.00013097204597975156 -1.9195601027282168e-5 -4.789256539120513e-5 5.225646389911943e-6 -4.689100799773412e-5 5.8699082245175536e-5 6.808745044359184e-5 0.00011012326964951215 -0.00013127835021045512 -3.6729340110152464e-5 -1.7497208802236452e-5 1.7144604015188676e-5 -7.418047655996658e-5 
-1.976166143471956e-5 7.44714193100235e-5 5.073193412349317e-5 -0.00014780445101827952 5.309369541314912e-5 0.00023372189704637278 -7.199993027065711e-5 9.098436518647696e-5; -0.00011346888861144181 -6.949473246921401e-5 3.856546898614005e-5 -2.4934232368744057e-5 5.946497693086857e-5 0.00011604997196709466 0.00016227626254559178 -0.0002037716352961956 1.293955048352677e-7 4.251151568439288e-6 -5.7714336608725405e-6 -8.03743991301365e-5 -9.655654359367279e-5 1.2816101462622196e-5 -0.00015810647384487537 0.00010147678911461968 -5.5986102251240775e-5 -2.2851416742103224e-5 2.4247322999967013e-5 -0.0002529838988440172 -6.8538292171894736e-6 9.62035535937691e-6 -0.00018944418336016054 -3.871332448315868e-5 1.7418580301544886e-5 -0.0001359462352585076 -5.8182302763351234e-5 -0.0001396953324902839 -1.651905806202581e-5 9.977937368685262e-5 -0.0001698688596835359 -5.3072779718689574e-5; 0.0001665917783953843 2.55261160760759e-5 -4.5622515756251e-5 4.952263168532273e-6 1.5503921019994497e-5 6.894197343863901e-5 3.452620279641123e-5 -0.00018490239978054222 7.067792961389196e-5 7.809145375868188e-5 0.000163283909785244 3.1647961293498955e-5 -5.638187082003411e-5 -1.810713780504896e-5 2.7463721778629413e-5 -5.7706862719343887e-5 0.00011514554847745515 -5.932460277326087e-5 1.4741193297238248e-5 5.4910456526711875e-5 -0.00011076217430162945 0.00014327150478274934 5.155987440739169e-5 -5.9466669483656524e-5 -4.4826165889318325e-6 -0.00013497905689509229 2.142949182380659e-5 0.00016063907008866216 -0.00013517397979957583 -1.2969849706125831e-5 -0.00012878180365295764 -7.320985110927106e-5; -5.042772437876127e-5 0.0001275519995790707 0.00017664694558631617 2.7074850261914338e-5 3.106326647254483e-5 -7.74154080470282e-5 2.635712615598837e-5 -8.919512521673043e-5 -0.00017936804698977647 0.00019517416756868123 -0.00011380223196896012 3.0022571838966304e-6 0.00017422297485906614 0.00013984085915878325 -0.00011179627228260063 -0.00010120334383530563 4.170433551979328e-6 
-6.724688152400105e-6 0.0001161084680309327 0.0001462663446393608 1.2365889854907577e-5 0.00010345797880587665 -0.00016086461738881174 -0.0001705885563903679 -0.0001102309445893581 3.697895470692744e-5 8.512637094396998e-5 -1.2211559184589694e-5 4.6733045276841014e-5 -0.00017336383945506107 -0.0001473083733041962 0.0001273211625478037; -7.243104161478432e-5 7.713913103782039e-5 8.269308048749428e-5 -0.00014968608999641505 0.0001436448106074939 -0.0002365851655371287 3.591411908257105e-5 -5.2789452154122295e-5 0.00015175721230900035 2.2465691956880045e-5 5.9585959691160185e-5 8.461298742315033e-5 9.848812945403256e-5 -0.0001926959343721874 -0.0001250911704731908 -0.00012281803034728286 -0.0001078010286322587 7.192023943627433e-5 0.000172423405758871 0.0002246582476187692 8.195900184784369e-5 -7.670080637342013e-5 -4.820091314047247e-5 -0.00022904258773213078 -9.887839811870703e-5 0.00016672483209969612 -4.848141222044506e-5 1.1754691088405738e-5 1.9050946805342948e-5 -3.781542588274762e-5 6.400038140110856e-5 -0.0001674742929639084; -1.5718237253732974e-5 -3.029841412045628e-5 -5.389333148616341e-5 -5.8591046243802755e-5 0.00023402575913804662 0.00017317589242464096 1.691762679127582e-5 -7.732891811303509e-5 -0.00010154806352931883 -0.00024931364589847883 -1.6614717456900966e-6 -5.866733829736305e-5 -0.00015228052928125427 -8.102471500585184e-6 4.540883980678487e-5 -3.4119822236083603e-6 -0.00015194307036710538 8.762071535909827e-5 -0.0001031089892682026 -0.00011814742595436026 -2.7875323784080507e-5 6.841947776720997e-5 4.276651031560508e-5 -5.842513136140562e-6 -0.0001353822195857191 6.715896904222827e-5 7.279836006217334e-5 7.211954959655306e-5 -5.961146291940447e-5 -0.00018192895061349264 6.413560492672312e-5 0.00014921035778696153; -6.717857539158449e-5 7.404958928692425e-7 -9.6863063645719e-5 3.285747081420191e-5 -0.00012890329597254894 -6.708590151945276e-5 0.00014147852127237303 -0.0001305714547747568 -3.343822461872532e-5 5.130395692990346e-6 
3.551064875822138e-5 9.943664026220644e-6 -6.39393481655812e-5 1.1977096819035557e-5 -2.2000912195357508e-5 -2.3118848530861594e-5 9.541789697280118e-5 -2.527881841203749e-5 5.7325130200776726e-5 -2.714791733567362e-5 2.024952411332228e-5 1.1267958343108025e-5 3.264424523835579e-5 5.1655238366517184e-5 -0.000218261339596642 -5.566707912337185e-5 6.775627773052191e-5 6.091257748223672e-5 -0.00010114409158675223 6.538098315970862e-5 1.8913150797313198e-5 5.4171104438435324e-5], bias = [1.3745099021056385e-9, -1.1260657321171465e-9, -1.2990523366594278e-9, -6.822369112170072e-10, -1.3866636763186096e-9, 1.4900486525651603e-9, -2.2875881767967213e-10, 5.560952440616665e-10, 3.934339069696598e-9, 2.71113209949016e-9, -1.5182534074884399e-9, 4.6711729322189006e-9, -1.7260754996722978e-9, 1.4785044515976402e-9, 1.0530521108990596e-9, 4.885512083411856e-9, -1.661359295841981e-9, -2.0602105719224092e-9, -1.3308409343860252e-9, 7.284066649815007e-10, 2.70861269886842e-9, 2.030807415757461e-10, -1.0412645105794088e-10, 7.858940672540973e-10, -4.0843600322787367e-10, 2.480618017901638e-9, -4.906195594512953e-9, 6.816095340207381e-10, 3.0075468110699015e-10, -8.116943567092488e-10, -2.085133585839897e-9, -1.1357592255067505e-9]), layer_4 = (weight = [-0.0006796229187517447 -0.0006885219614447404 -0.0008209599882657192 -0.0007226516234246453 -0.000769996047746865 -0.0005627981479416057 -0.0007300340080805408 -0.0008168459445804017 -0.000649263921592931 -0.0007728948282439503 -0.0006427888655592365 -0.0007979372027114778 -0.0005221582868121625 -0.0006872430203589204 -0.0007979230016248345 -0.0005543949368446674 -0.0007278524533554609 -0.0007005207312868305 -0.0007336844382200704 -0.0006166533290428479 -0.0004916600480851606 -0.0007460684544902778 -0.0007390189871688624 -0.0006361657063596585 -0.0007075376132101468 -0.0007480060010151403 -0.0008161744050781929 -0.0006121170182111219 -0.0007152020422333129 -0.000648972628172796 -0.0005441629667900748 -0.000792745891967862; 
0.00017952940183113723 0.0003263865559659372 4.3860354666672045e-5 0.00016490562463469918 0.0001070556538599047 0.00021808813097379853 0.00020534936538392078 0.00024465464310896004 0.0002938789327785234 0.0002635788589920523 0.0003024008234304869 0.0002120558978664538 0.00017259671007966042 0.00044922471531752447 0.0003219783451179133 0.00018169264059199883 0.00024183138031107102 0.00021327784117380018 0.00032342879536446126 0.00024166817490903594 0.0002486537650108981 8.495466962447409e-5 0.00038475555567897535 0.0003363040256151035 0.00030251060351534965 0.00011105446102090704 0.0002857472157684033 0.0002766641736059404 0.000261832904193463 5.695763907197875e-5 0.00012688831300815284 0.00028002498669712296], bias = [-0.0006986574642761802, 0.0002315581436121987]))Visualizing the Results
Let us now plot the loss over the course of training.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    # Both the line and the markers use the explicit iteration index 1, 2, …, n as x-values.
    lines!(ax, 1:length(losses), losses; alpha=0.75, linewidth=4)
    scatter!(ax, 1:length(losses), losses; marker=:circle, strokewidth=2, markersize=12)
    fig
end
Finally let us visualize the results
# Re-simulate with the optimized parameters `res.u` (same fixed-step RK4 settings) and compare
# the result against the data and the untrained prediction `waveform_nn`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained =
    first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Data, untrained prediction, and trained prediction: each as a line plus markers.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform;
        marker=:circle, markersize=12, alpha=0.5, strokewidth=2,
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn;
        marker=:circle, markersize=12, alpha=0.5, strokewidth=2,
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, alpha=0.5, strokewidth=2,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
# Additionally report GPU toolchain details, but only for backends that are both loaded and
# reported functional by MLDataDevices.
if @isdefined(MLDataDevices) &&
   @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) &&
   @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.5
Commit 5fe89b8ddc1 (2026-02-09 16:05 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, icelake-server)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.