Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
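This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.
We need a very crude 2-body path. Assume the 1-body motion is the Newtonian 2-body relative position vector $r = r_1 - r_2$. With total mass $M = m_1 + m_2$, the individual positions are then $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$.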
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of two bodies of
masses `m₁` and `m₂` about their common centre of mass at the origin. Returns
`(r₁, r₂)`, where `r₁ = m₂/(m₁+m₂) · path` and `r₂ = -m₁/(m₁+m₂) · path`, so that
`r₁ - r₂ ≈ path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution expressed in orbital variables into Cartesian `(x, y)` positions.

`soln` has one column per time sample and either:
- 2 rows `(χ, ϕ)`: then `model_params = (p, M, e)` must supply constant `p` and `e`
  (`M` is not used here); or
- 4 rows `(χ, ϕ, p, e)`: then `p` and `e` vary per sample and `model_params` is ignored.

The radius is `r = p / (1 + e cos χ)`. Returns the `2 × N` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.

Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if it has 2 rows and
`model_params` is missing or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Test for `nothing` first: with the default argument, `length(nothing)` would
        # otherwise throw a MethodError instead of this descriptive assertion.
        @assert(
            !isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v, dt)

First time derivative of the uniformly sampled series `v` with sample spacing `dt`.
Interior points use central differences. The first and last points use second-order
one-sided stencils `(-3/2, 2, -1/2)` and `(3/2, -2, 1/2)`. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    tail = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    centred = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(head, centred, tail) / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v, dt)

Second time derivative of the uniformly sampled series `v` with sample spacing `dt`.
Interior points use the central stencil `(1, -2, 1)`. The first and last points use the
second-order one-sided stencil `(2, -5, 4, -1)`. Requires `length(v) ≥ 4`. The result
is built without mutation, so it can be differentiated by reverse-mode AD.
"""
function d2_dt2(v::AbstractVector, dt)
    first_pt = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    last_pt = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    inner = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    return vcat(first_pt, inner, last_pt) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free mass-moment tensor along `orbit`. The rows of
`orbit` are `x` and `y`, and the columns are time samples.

- `(1, 1)` → `mass · (x² - (x² + y²)/3)`
- `(2, 2)` → `mass · (y² - (x² + y²)/3)`
- any other pair → `mass · x·y` (the off-diagonal component)
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i, j = component[1], component[2]
    r² = x .^ 2 .+ y .^ 2
    moment = if i == j == 1
        x .^ 2 .- r² ./ 3
    elseif i == j == 2
        y .^ 2 .- r² ./ 3
    else
        x .* y
    end
    return mass .* moment
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the selected
trace-free moment-tensor `component` along `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(1,1)`, `(1,2)` and `(2,2)` quadrupole components, in that order, for a
single body of mass `mass` following `orbit`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(c) = h_22_quadrupole_components(dt, orbit, c, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end
"""
    h_22_strain_one_body(dt, orbit)

(2,2)-mode strain of a single unit-mass body following `orbit`. Returns
`(√(π/5)·(h11 - h22), -√(π/5)·2h12)`. The scaling constant is built in the type of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    amplitude = sqrt(T(π) / 5)
    hplus = amplitude * (h11 - h22)
    hcross = -amplitude * (T(2) * h12)
    return hplus, hcross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Component-wise sum of the `(1,1)`, `(1,2)`, `(2,2)` quadrupole terms of two bodies.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q1 = h_22_quadrupole(dt, orbit1, mass1)
    q2 = h_22_quadrupole(dt, orbit2, mass2)
    return q1[1] + q2[1], q1[2] + q2[2], q1[3] + q2[3]
end
"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

(2,2)-mode strain from the orbits of two bodies with masses `mass1` and `mass2`, which
must sum to 1. Returns `(√(π/5)·(h11 - h22), -√(π/5)·2h12)`, with the scaling constant
built in the type of `dt`.

Throws an `AssertionError` if `mass1 + mass2` differs from 1 by more than the tolerance:
`1e-12`, or a few ulps of the masses' floating-point type if that is larger.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2
    total = mass1 + mass2
    # A fixed 1e-12 tolerance cannot be met by low-precision masses (eps(Float32) ≈ 1.2e-7).
    # Widen it to a few ulps of the masses' type, which still leaves 1e-12 for Float64.
    tol = max(1.0e-12, 4 * eps(float(typeof(total))))
    @assert abs(total - 1.0) < tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Convert an orbital ODE solution into the (2,2)-mode strain pair.

`mass_ratio` must lie in `[0, 1]`. A value of 0 is the test-particle case and uses the
one-body strain. Otherwise the unit total mass is split as `m₂ = 1/(1+q)` and
`m₁ = q/(1+q)`, the relative orbit is divided between the two bodies, and the
two-body strain is used.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
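RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2 \sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$

where $p$ is the semi-latus rectum, $e$ the eccentricity, and $M$ the mass, so that the orbital radius is $r = p / (1 + e\cos\chi)$.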
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place right-hand side for the state `u = (χ, ϕ)` of a test particle in a
Schwarzschild background, with semi-latus rectum `p`, mass `M` and eccentricity `e`.
Returns `[χ̇, ϕ̇]`. The time `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)
    common = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)
    dχ = common * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root)
    dϕ = common / (M * (p^(3 / 2)) * root)
    return [dχ, dϕ]
end
# Reference-simulation settings. These are globals read later by the plotting code and by `loss`.
mass_ratio = 0.0 # test particle: compute_waveform uses the one-body strain when this is 0
u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
datasize = 250 # number of saved samples
tspan = (0.0f0, 6.0f4) # time span for the GW waveform
# NOTE(review): tspan is Float32 while u0 and the ODE parameters are Float64, so `tsteps`
# and `dt_data` are presumably Float32 — confirm this mixed precision is intended.
tsteps = range(tspan[1], tspan[2]; length=datasize) # save times, uniformly spaced
dt_data = tsteps[2] - tsteps[1] # sample spacing; used for finite-difference derivatives
dt = 100.0 # fixed integrator step (the solves below use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic ("true") model with fixed-step RK4, saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Reference waveform: the first component returned by compute_waveform (the plus polarization).
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the reference waveform; the line and markers share a single legend entry.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
axislegend(ax, [[l, s]], ["Waveform Data"])
fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the first state variable χ, not the time) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.
Using globals is generally not recommended, but if you do use them, make sure to mark them as const.
We deviate from the standard neural network initialization and use WeightInitializers.jl. Weights are drawn from a truncated normal distribution with a very small standard deviation (1e-4), and biases are initialized to zero:
# Network used as a multiplicative correction inside `ODE_model`.
const nn = let
    # One shared initializer. Tiny truncated-normal weights and zero biases keep the
    # untrained network's output close to zero.
    weight_init = truncated_normal(; std=1.0e-4)
    Chain(
        # Feed cos(input) rather than the raw input, so the output is 2π-periodic in it.
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight=weight_init, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight=weight_init, init_bias=zeros32),
        Dense(32 => 2; init_weight=weight_init, init_bias=zeros32),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[1.3351865f-5; 0.00011501337; -9.139098f-5; -9.221634f-5; -1.766013f-5; 8.164934f-5; -0.00013934336; -2.458239f-5; -0.00015962713; 0.00020397389; 0.0001726338; 3.0771716f-5; -7.927738f-5; -5.487881f-5; 0.00012225012; 5.6274035f-5; 7.3977186f-5; 9.0784786f-5; -4.3031443f-5; 6.3171f-5; 7.6221906f-5; 0.00011198696; -9.725841f-5; -5.693377f-5; -0.00012532504; -8.7009576f-5; -5.1365096f-5; -3.5635174f-5; 7.384907f-5; 0.00012852688; 1.4420486f-6; 0.00015116259;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-2.5368336f-5 0.00017731966 0.00010165056 -0.000109757304 -0.00011131852 -5.623552f-6 -0.00011260631 2.2029105f-6 1.23656355f-5 1.2648378f-5 -0.00026055903 0.0001276019 -0.000103812905 -8.535584f-5 0.00019820164 8.1967795f-5 9.905469f-5 -0.00013109986 9.921065f-7 0.00011903567 -6.564398f-5 -0.00012102384 7.493884f-5 -0.00018714645 0.00010146068 -0.00019731057 -6.291555f-5 9.215899f-5 0.000102600905 -0.00021959105 0.00012352465 0.00012218642; -6.609272f-5 0.00011682667 -0.00016060835 5.7564754f-5 7.986149f-5 6.2703024f-5 8.826461f-5 -0.00017981675 -0.00012241951 -3.6293222f-5 2.3702243f-5 -0.00014455583 -6.779959f-5 6.329955f-5 -0.00012561305 -9.716504f-5 2.5910615f-6 -0.0001034874 9.733339f-5 8.626836f-5 -7.586784f-5 0.00013393645 0.00018110708 7.100105f-5 -4.6872858f-5 0.00010065484 2.9689869f-5 -4.1181458f-5 -0.00019219187 6.8434245f-5 0.00017971713 0.00015173803; -6.958085f-6 5.255087f-5 0.0001747039 -0.00022974165 1.1100002f-5 -5.52065f-5 6.512339f-5 -0.00016834146 -0.00014626009 -4.42196f-5 7.479232f-5 6.506744f-5 5.2508814f-5 0.00010151165 -0.0002038698 0.00018379043 -1.8371586f-5 -0.00015654603 7.259918f-5 5.4536685f-7 -0.00014395532 -0.00020608857 -0.000149267 -1.3108302f-5 6.0858434f-5 -5.005461f-5 
4.985169f-5 5.4797506f-6 -2.9039576f-5 -0.00010409181 -0.00011546241 -6.790035f-5; 6.179134f-5 8.55433f-5 -0.00014653533 8.744055f-5 -0.00023121489 0.00015660044 -7.902376f-5 -7.0291455f-5 0.000114070404 0.00012974959 9.7069686f-5 5.5885143f-5 7.3663716f-5 -3.875807f-5 -7.487087f-5 -7.1437775f-5 0.00014767131 -0.00013123425 -3.7446247f-5 -6.455883f-5 -7.850997f-6 3.053576f-5 -0.00012618124 0.00019729826 6.3013154f-6 8.545343f-5 -3.7794427f-5 6.063066f-5 2.3576384f-5 1.1162721f-5 5.779429f-5 9.6468684f-5; -3.577357f-5 -5.7625607f-6 0.0002122046 8.4352505f-5 -0.00016385996 3.242179f-5 -1.278502f-5 -1.6889762f-5 0.00015745097 6.290422f-5 5.349753f-5 -4.855499f-5 -1.3478534f-6 4.596686f-5 -6.919534f-5 7.60175f-5 -3.392668f-5 -3.7508034f-5 -9.381199f-5 -5.443022f-5 -5.016074f-5 -6.85801f-5 6.6341796f-5 9.729811f-5 6.226241f-6 0.00010193281 0.00013239718 5.026363f-6 -0.00019899457 -0.00015828975 -0.0001770425 -1.361882f-5; -5.4724555f-5 0.00028814058 0.00010317065 -3.373396f-5 -2.1652644f-5 3.4547233f-5 5.4372496f-5 -0.0001782397 -0.00020604001 -0.00015564152 0.00010062197 7.9451325f-5 -0.00012328087 5.17915f-5 5.0095787f-6 -2.3068997f-5 7.308341f-5 -0.0001300755 -3.4641784f-5 3.3137156f-5 -4.185589f-5 0.00011585339 2.2297725f-5 4.196892f-5 -7.472117f-5 -9.7251075f-5 0.00012577767 3.3867735f-5 -9.86988f-5 -0.0001441409 6.0759266f-5 -9.763168f-6; -1.2933604f-5 -7.13899f-5 1.1197149f-5 -9.500735f-5 3.310338f-5 0.000115232106 -0.00016026551 -5.2401938f-5 -0.00011535539 -0.00021496206 -3.1160896f-5 -2.0335201f-5 3.461271f-5 -3.0745447f-5 -4.423107f-5 -3.7222864f-5 -0.0001953056 9.156658f-5 -0.00010754825 -0.00011611512 -1.42426725f-5 -9.2751136f-5 -7.684298f-5 6.928847f-5 9.2506736f-5 -5.5220407f-5 -1.7774626f-5 -0.00017286773 -9.738207f-5 1.2563727f-5 0.00012832256 4.7943446f-5; 2.9047713f-5 4.660118f-5 -3.909855f-5 -7.739065f-5 -3.526812f-5 -4.7962232f-5 -6.650388f-6 -6.835183f-5 -2.2527367f-5 -6.3795585f-7 -0.00014795306 -7.080644f-5 0.0001566654 7.6176046f-5 3.414899f-5 
2.7043649f-5 2.767983f-5 0.000119860226 4.4657292f-5 -6.278584f-5 -9.072174f-5 -1.6716791f-5 -0.000120494544 1.2997401f-6 8.628273f-5 1.7057071f-5 5.5329423f-5 3.661774f-5 0.00019448414 2.3860866f-6 -0.00022672329 0.00010548425; -2.933228f-5 -6.642776f-5 0.00020103116 0.00018618954 0.00018992189 3.16669f-5 -8.73525f-6 -1.32103f-5 -9.613073f-6 -3.4680415f-5 0.00020607722 0.00010024505 3.9156923f-5 1.4552569f-5 0.00022038457 -0.00011287938 2.8062552f-6 -5.96939f-5 4.4281853f-5 -5.338706f-5 5.35995f-5 2.8656921f-5 -0.00014901694 3.5704124f-6 -1.6054908f-5 -2.2702257f-5 7.671327f-5 -8.7214554f-5 -8.492343f-5 -5.111188f-5 9.900603f-5 -1.8655626f-5; 8.755795f-5 9.1728994f-5 -0.00016132387 0.00018059694 -7.917525f-5 -7.592465f-5 -0.00018475509 -4.635819f-5 3.7193895f-5 -3.571101f-5 0.00014863361 -3.847424f-6 5.8630918f-5 -2.7369713f-5 -2.0845397f-5 3.7119796f-6 3.1767828f-5 -0.00013535337 0.00012428216 0.00011146658 -0.00015064268 9.671728f-5 -5.667151f-5 0.00010459404 0.0001536151 2.3375638f-5 -7.074227f-5 -5.2809624f-5 -1.4586138f-5 -2.3679107f-5 -0.00011707722 8.046575f-6; -2.6526373f-5 9.505127f-6 -0.00015520868 0.00012412784 -0.00028683312 8.2342805f-5 -0.00010570294 -0.00014426227 0.00017996947 1.8002276f-5 -0.00012336434 0.00010729047 2.7203274f-5 0.00010018347 -4.2602762f-5 7.851821f-6 4.8934682f-5 2.812206f-5 -9.09068f-6 -6.062577f-5 -0.00025894676 0.00011298299 0.00021754506 -4.7110905f-5 4.2766253f-5 2.1820952f-5 4.0771763f-5 -6.818283f-6 1.2622204f-5 1.1660895f-5 8.064015f-5 0.000111124726; -2.3752758f-5 -4.4248418f-7 -0.00013225833 -3.2603264f-5 0.0002495613 -4.553239f-5 0.00018672427 5.5150507f-5 0.00019436842 -5.4505654f-5 -4.331083f-5 2.9246792f-6 8.7612345f-5 2.8294537f-5 -2.1536585f-5 3.0386565f-5 -1.1066389f-5 -0.000108497894 -0.000265124 6.4691034f-5 -0.00012871202 0.00019676783 9.4230985f-5 -5.871271f-5 -4.247195f-5 1.9133773f-5 0.00015740185 8.5031745f-5 8.3345556f-5 -0.00014200433 8.8961606f-5 -1.4566702f-5; -0.00012903017 2.1368765f-5 2.8619263f-5 
-4.105433f-5 -0.00024778108 2.1839833f-5 -0.00018311726 -3.5500107f-5 -8.00983f-5 -0.000108588894 9.482476f-5 4.745248f-6 3.8759892f-5 2.1132737f-5 3.2507953f-6 -0.00012271207 9.405543f-5 2.16368f-5 -1.122119f-5 0.00011144033 0.00010823025 -1.9468465f-5 -6.6453717f-6 7.024206f-5 -0.00012501245 -8.2316656f-5 -2.2230932f-5 -7.657273f-5 -9.7192875f-5 -2.5624435f-5 1.730028f-5 0.00017808947; 0.00013928708 4.750588f-6 2.9991414f-5 -3.0929452f-5 9.081602f-5 -1.30693315f-5 -3.6892423f-5 -4.7272144f-5 -1.1561063f-5 9.718215f-5 -0.00012686213 5.40502f-5 0.00027864927 -6.114582f-5 -0.00016060827 0.00015771827 8.731322f-5 -3.199454f-5 -0.00016640403 4.356646f-5 0.00014216782 -8.629903f-5 0.0001670911 0.000121039295 6.0157494f-5 -2.2540073f-5 4.543145f-5 0.00012480118 6.306689f-6 -5.113061f-5 4.1785657f-5 -3.9231687f-5; 8.47045f-6 -0.00023209014 -5.7277884f-5 -9.594935f-5 0.00013286981 4.3049655f-5 -5.6311066f-5 -0.00015826299 0.0002359442 -9.880111f-5 -2.3683899f-5 0.00032224736 0.0002462383 0.00010873708 -0.00011158311 -3.7788664f-5 2.04421f-5 0.00019315412 6.0827475f-5 0.0001753617 6.004933f-5 -6.549711f-5 4.438263f-5 0.00010341042 -0.00012719438 2.3438766f-5 0.000102514816 -7.5746204f-5 -3.9109324f-5 -0.00013781677 1.2413814f-5 -3.9322982f-5; -5.7096146f-5 -6.9862384f-5 1.6129477f-5 0.00016278638 -0.00025026916 0.00015444857 0.000121761586 0.00018290618 -1.9657964f-5 -0.0003218335 -1.2372802f-5 5.8558064f-5 -0.00029009642 -2.8144257f-5 -1.773568f-5 7.4990225f-5 -4.651468f-5 -2.7966116f-5 -0.00020282199 -6.861452f-5 -2.5189795f-5 2.8282146f-5 5.4256298f-6 -0.00023961997 -9.4247516f-5 0.00010382471 -0.00011942697 0.00016883809 9.4177005f-5 -7.709966f-5 1.1788601f-5 2.9910203f-5; 4.5674846f-5 3.7014095f-5 0.00020801701 -4.2401232f-5 -9.683505f-5 9.433496f-5 1.8090446f-5 -6.6523426f-5 3.503124f-5 8.942931f-5 -9.859641f-5 -4.976704f-5 0.00021783402 1.8479372f-5 7.6741795f-5 -5.3876254f-5 5.0500494f-5 5.4169072f-6 1.1322893f-5 6.0995462f-5 -6.07602f-5 2.0567995f-5 7.319299f-5 
1.107825f-5 -0.00012071332 0.00011953051 0.000118301316 9.7018135f-5 3.8231577f-5 4.09694f-5 6.0904455f-5 9.706034f-6; 5.8920738f-5 0.00011633883 -5.1182167f-5 -8.853518f-5 2.8556606f-5 0.00015437191 -2.001385f-5 -8.381056f-5 -6.5271255f-5 -0.00017186538 7.679154f-6 9.39485f-5 -0.00024115472 -8.472139f-5 -0.00019332588 4.1934636f-5 -5.4246125f-6 -0.00023760051 -1.9766609f-5 0.00019113523 -4.786237f-5 -0.00013867357 -0.00016649257 -6.3429484f-6 -4.4755398f-5 0.00013311785 -2.4959782f-6 8.913349f-6 -0.00022963597 -0.00012571126 0.000115255265 -0.00016426946; 4.8174676f-5 3.4056924f-5 -0.00015444154 7.2432995f-5 -8.20147f-5 3.271061f-5 0.00021853125 -8.3893705f-5 -0.00011434591 5.3574463f-6 9.1034395f-5 0.00012973184 0.00021806375 -6.743133f-5 -6.041677f-5 -1.0851408f-7 -0.00016395425 0.00015871445 0.00021568629 0.00012319373 -2.0550237f-5 1.23822865f-5 -8.0851176f-5 -1.5744377f-5 -8.084813f-5 0.00015648632 3.1118472f-7 -0.00015950187 8.76384f-5 -0.00019804596 -1.9030178f-6 5.3146043f-5; 4.995715f-5 0.00016001261 0.00017501928 -0.0001749246 0.0001392355 0.00023286034 -7.6920136f-5 1.1300129f-5 -4.0072177f-6 2.9113375f-5 -0.00017888384 0.00011802518 -5.491966f-5 -5.6915826f-5 -6.1372935f-5 -1.4161055f-5 -4.8802176f-5 -8.461665f-6 4.301917f-5 0.0002292568 8.236257f-5 -3.3310625f-5 3.5459612f-5 2.840587f-5 -6.0679984f-5 -1.795399f-5 -8.827761f-6 0.00011801388 -2.9727998f-5 -0.00019221603 -0.0001510439 0.00010837457; -0.00016014562 -0.00014887023 -0.00019473315 9.373317f-6 5.579598f-5 8.380413f-6 7.785983f-5 -8.778523f-5 -4.1336778f-5 -0.00012034475 0.0001020765 -5.3233376f-5 -0.00016944959 -7.2702082f-6 2.2047501f-5 8.6746186f-5 -1.7606719f-5 -9.726099f-5 -4.6141984f-5 7.099878f-5 -2.5543959f-5 -6.475389f-5 0.00012308318 0.00015206638 0.00019220762 4.6534693f-5 7.2598596f-5 0.00011472841 0.0001886122 -3.6183188f-5 -0.00011720764 6.400023f-5; -4.60015f-5 -5.405779f-5 6.0509706f-8 1.7794343f-7 4.41758f-5 3.4379504f-5 -4.123117f-5 -0.00013112617 -8.189107f-5 -0.00016251404 
-0.00022023889 8.438981f-5 -4.6710997f-5 -0.00017822228 0.00015179082 0.00013634122 -0.00020969326 0.00010633975 -0.00020339365 5.9011905f-5 6.541735f-5 -5.375423f-5 -2.535583f-7 -3.1422693f-5 4.1038897f-5 7.868364f-5 0.00016214045 6.339378f-5 -9.5922594f-5 -3.17266f-5 -3.0706666f-5 3.9023425f-6; 4.9993898f-5 9.589441f-5 -0.0001195359 -4.5210843f-5 2.4238887f-5 9.084017f-5 6.436822f-5 -2.7926799f-5 -0.00015949662 0.00018012483 4.0122f-5 1.735714f-5 -7.6493336f-5 9.592989f-5 -8.150625f-5 -1.4205364f-5 8.5963664f-5 5.781318f-5 1.8281715f-5 -5.8000518f-5 0.00012248733 -5.4156568f-5 -9.6150725f-5 2.068022f-5 -9.0217814f-5 -0.00010036648 4.465132f-5 -1.0626314f-5 -9.992459f-5 -5.459152f-5 -0.00018442747 3.1357853f-5; 0.00011973175 0.0001916945 0.00014207362 -4.1494797f-5 -3.9416784f-5 -6.389922f-5 -0.00011381287 0.00014392564 7.669612f-5 1.3671706f-5 8.879458f-5 0.00016315757 2.7548907f-5 -9.379093f-5 -0.00017071926 -5.1208102f-5 0.00011317268 -8.9789155f-6 0.00012603031 -0.00014454186 3.3915876f-5 -3.1175012f-5 -0.00017320942 8.2073304f-5 -4.421612f-5 -4.4853616f-5 -0.00011637147 -9.324675f-5 -6.0522743f-5 8.996828f-5 -4.704273f-5 3.3644028f-5; -4.702452f-5 -7.1314826f-5 -0.00014585104 0.00011563411 -7.834357f-5 0.0001724104 7.933897f-5 -2.3135193f-5 -1.8953866f-5 2.8682203f-5 -0.000182928 -4.536547f-5 -7.372791f-5 1.2735239f-5 -6.668765f-5 -0.00019333471 -4.7616213f-5 7.67649f-5 -0.00015772045 -0.00013401345 -2.212751f-5 -0.00014126328 7.0266535f-5 -1.38687365f-5 2.2535412f-5 -0.00010459436 -5.7661564f-5 -0.00010725689 -0.00016179853 -1.4083293f-5 4.9781327f-5 -3.7322316f-5; -5.7350644f-5 -0.00013489733 0.00012394598 9.5001196f-5 0.00018210782 3.3172295f-5 -6.0906033f-5 -0.00010469536 0.00012728311 -0.00010706508 -1.0006662f-5 -0.00012390362 9.8590615f-5 3.917149f-5 0.00018834524 5.831137f-5 7.835843f-5 -1.7456467f-5 0.00010988174 5.1601994f-5 -4.010655f-5 -5.6049514f-5 1.3723392f-5 1.4914289f-5 -8.015519f-5 1.4173082f-5 -4.694603f-5 -5.1908868f-5 -1.8650469f-6 
3.5739995f-5 -5.401052f-5 0.0001859828; 0.00017365144 0.00016729659 9.291057f-6 -6.03364f-5 0.00016490284 0.000107774256 -1.4036697f-5 -6.522078f-5 0.00013607177 -9.797696f-6 0.00014712832 -4.16427f-5 -0.00019780837 -3.1197676f-5 -1.5233805f-5 1.3759122f-5 -2.7621049f-5 0.00018195379 3.427445f-5 -9.3116534f-5 -2.250757f-5 -6.0712806f-5 1.7035696f-5 4.898114f-5 0.00010302704 -0.00015498116 -0.00011197166 0.00012610121 -3.8185713f-5 -6.0521434f-5 -0.00011016113 -5.7514484f-5; -0.00011108663 1.3867846f-5 2.3490231f-6 -1.527862f-6 -1.1993402f-5 -0.00023308441 0.00015876026 1.0843682f-5 -1.0984481f-5 0.0001808538 -1.7798655f-5 -0.0001058086 5.289755f-5 0.00018547165 -8.28774f-5 1.98559f-5 8.039206f-5 -8.136577f-5 -2.6690173f-5 4.4096352f-5 -9.151616f-5 -5.4009015f-5 0.00014298887 6.4614804f-5 -6.3861655f-5 -0.0001338502 -1.8112163f-5 -0.00017004622 -8.767964f-5 3.4237743f-5 6.1650004f-5 0.00011232359; -0.00018050332 -0.00018322698 0.00018206229 0.00027969258 -2.2707009f-5 -0.00010263593 6.906611f-5 4.850213f-5 -5.733599f-5 0.00013591717 -0.00010053855 -5.6452234f-5 8.917962f-5 0.0001024644 5.3932607f-5 -0.00017285599 9.209203f-5 7.558389f-5 8.604908f-5 9.399564f-6 -4.04771f-6 -2.2318785f-5 7.923977f-5 9.3322866f-5 -2.5843749f-5 -0.00010412771 6.000918f-6 -1.5172273f-5 -8.020359f-5 -5.683145f-5 1.7112105f-5 4.9578794f-5; 8.557223f-5 0.00014599795 -7.7663026f-5 9.072679f-5 7.175225f-5 -2.579526f-5 -0.00010822117 -5.354313f-7 4.275542f-5 -4.7146946f-7 0.0001834963 0.000117127696 -1.8742834f-5 -7.7562865f-5 -0.00010919036 0.00015509898 -0.00013864005 3.4345587f-5 -0.00015270904 3.9580475f-5 -1.0076068f-5 2.8196784f-5 0.0001384174 -8.800294f-5 8.056892f-5 -5.472531f-5 -5.915061f-5 6.462785f-5 -4.6236117f-5 -4.2346004f-5 -6.0430953f-5 8.7133245f-5; 6.5536456f-6 0.00013152145 5.7669575f-5 0.00013512016 -0.00012858726 0.00025589304 0.00016779434 -1.6804595f-5 5.6187833f-5 -0.00013812879 -1.6328568f-5 4.654199f-5 2.8752904f-5 -5.1763887f-5 2.469241f-6 2.4451225f-5 -0.000172994 
3.7997655f-5 -5.572827f-5 0.0002158217 4.5841673f-5 2.9493174f-5 2.5253079f-5 4.417799f-5 -2.6798225f-5 0.00012737853 -0.00021870546 2.7838336f-5 -0.00014749008 5.7180245f-5 2.7787879f-5 0.0001212489; 5.4205935f-5 -5.3786734f-5 -0.00022390642 0.00019462673 -8.051915f-5 4.609251f-5 8.604006f-5 -4.386169f-5 7.103663f-5 -1.6768352f-5 7.588215f-6 0.00021521693 -6.979035f-5 -9.655512f-5 -0.00014713191 6.61219f-7 -4.4684988f-5 1.0195412f-5 7.2643066f-5 7.97317f-5 0.00021705721 -0.00014246801 7.436993f-5 -0.0001288065 1.8596927f-5 -1.1460492f-5 -0.000119551696 -0.00015491391 -0.00013915202 -0.000114962124 1.1451136f-5 0.00011345222], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-6.230103f-6 -0.00013205009 4.8677866f-5 4.228672f-5 2.9894863f-5 0.0001368446 1.6852106f-5 0.00014462254 6.508936f-5 5.950981f-5 -0.00014181301 -6.320878f-5 1.2744734f-6 1.8871016f-5 4.9656148f-5 0.00010991427 -7.164112f-7 -3.4161596f-5 3.635686f-6 1.8856417f-5 -0.00025963434 -3.2463377f-5 3.8970633f-5 -4.760614f-5 -2.4697973f-5 -0.000109266795 -3.1851658f-5 5.569127f-5 -0.00014360357 -0.00011966185 0.00016400882 -0.00012190631; 9.542966f-5 2.101951f-5 -1.4598776f-5 -2.7698225f-5 7.65815f-5 -2.3961957f-5 2.3551453f-5 -7.548656f-5 0.00012674407 -2.3956054f-5 -3.829975f-6 5.894426f-7 7.1266506f-5 0.00012526025 3.3844513f-5 -8.198882f-5 0.000115822695 6.844636f-5 1.2974841f-6 3.431651f-5 3.5687083f-6 0.00020987593 0.00011534882 -4.0297702f-5 6.9016336f-5 8.014447f-5 -0.00013655261 -0.00011456531 -9.1517875f-5 -3.1551463f-5 0.0001149616 -5.3336717f-6], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64
# Lux initializes parameters as Float32. Promote them to Float64 to match the Float64 ODE
# state, and wrap them in a ComponentArray (a flat vector with named sub-blocks) so the
# optimizer can treat them as a single parameter vector.
const params = ps |> f64 |> ComponentArray
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
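# plus 0 states.
Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, using

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left[1 + \mathcal{N}_1(\chi)\right], \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left[1 + \mathcal{N}_2(\chi)\right],$$

where $\mathcal{N}_1$ and $\mathcal{N}_2$ are the two outputs of the neural network (with $\mathcal{N}_1 = \mathcal{N}_2 = 0$ the system reduces to the Newtonian model), and $p$, $M$, $e$ are the same parameters as in the relativistic model.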
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = (χ, ϕ)`. Both Newtonian rates
`(1 + e cos χ)² / (M p^{3/2})` are multiplied by `1 + ` the corresponding network output.
The orbital parameters come from the global `ode_model_params`, and `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # `nn_model` is a StatefulLuxLayer, so the (empty) Lux state is carried inside it and
    # only the parameters are passed. The network sees χ alone.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Neural-ODE counterpart of the reference problem: same u0 and tspan, with the untrained
# Float64 network parameters as `p`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# Same fixed-step RK4 settings as the reference solve; passing `u0` and `p` explicitly
# mirrors the call inside `loss`.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Waveform (first component) predicted by the untrained neural ODE.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
# Overlay the reference waveform and the untrained neural-ODE prediction.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
# Reference data: line plus markers.
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
# Untrained neural-ODE prediction: line plus markers.
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
# Each [line, scatter] pair is grouped under one legend label; legend in the lower-left.
axislegend(
ax,
[[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"];
position=:lb,
)
fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
# Mean-squared error between the predicted and reference waveforms.
const mseloss = MSELoss()

"""
    loss(θ)

Integrate the neural ODE with parameters `θ`, using the same fixed-step RK4 settings
and save times as the reference solve. Convert the solution to a waveform and return
its MSE against the reference `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h_pred, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h_pred, waveform)
end
# Warmup the loss function
loss(params) # ≈ 0.0007221154735791587 with the untrained parameters
# Now let us define a callback function to store the loss over time
# Loss history; `callback` appends one entry per call.
const losses = Float64[]

"""
    callback(state, l)

Record the loss `l` in `losses` and print a progress line using the iteration counter
`state.iter` (the first argument is the optimizer state, not the parameters). Always
returns `false`, so the optimization is never asked to stop early.
"""
function callback(state, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, l)
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This works well here because the Newtonian model already provides a very good initial guess.
# Gradients come from Zygote. BFGS uses a backtracking line search, an initial step norm
# of 0.01 and at most 1000 iterations; `callback` records the loss history.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback=callback)
# retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [1.3351865163701371e-5; 0.0001150133684859521; -9.139098256119064e-5; -9.221633808910904e-5; -1.7660129742552175e-5; 8.16493411548815e-5; -0.0001393433631164398; -2.4582390324209873e-5; -0.0001596271322337962; 0.00020397389016536577; 0.00017263379413623005; 3.0771716410522506e-5; -7.927737897234726e-5; -5.487880844155728e-5; 0.00012225011596440812; 5.6274035159735905e-5; 7.397718582062198e-5; 9.078478615258412e-5; -4.303144305590022e-5; 6.317099905566065e-5; 7.62219060561332e-5; 0.00011198696301994831; -9.725840936868199e-5; -5.693377170239453e-5; -0.000125325037515799; -8.700957550892886e-5; -5.1365095714552814e-5; -3.5635173844601326e-5; 7.384907075890862e-5; 0.00012852688087143957; 1.4420486422753786e-6; 0.00015116258873591854;;], bias = [1.281146874201935e-17, 9.813355245818672e-18, -9.493614084576486e-17, -1.0021134321195361e-16, -3.0859283135541515e-17, -4.2233967191683655e-17, -5.898257053171821e-17, -2.9246521452006043e-18, -3.990384134305674e-16, 4.710715930197477e-16, 1.2479318533380437e-16, 4.683929061412801e-18, -2.951324258050782e-16, -1.586393305390598e-17, 1.6713989227469493e-16, -5.0757285286293406e-18, 7.950711057351996e-17, 7.59161912114804e-17, -5.617202690808181e-17, 7.290383082990892e-17, 3.0551065376021377e-17, 2.434538622197999e-16, -1.345556289315499e-16, -7.214316990380887e-17, 8.284086742745758e-17, -7.699015180445418e-17, -2.2224817868059377e-17, -4.1248574265745146e-17, 1.2608606731074913e-16, 7.00137129201382e-17, -3.7982719137690902e-19, 1.3463060484856712e-16]), layer_3 = (weight = [-2.5369291578683938e-5 0.0001773187004080698 0.00010164960705951675 -0.00010975825996792983 -0.00011131947674512292 -5.624508120673038e-6 -0.0001126072630351666 2.2019545856868744e-6 1.2364679533498777e-5 1.2647422336896212e-5 -0.00026055998303417505 0.00012760094698096281 -0.0001038138608048036 -8.535679360575322e-5 0.00019820068420886152 8.196683931673603e-5 9.905373455767317e-5 
-0.00013110081916165657 9.911505419930343e-7 0.00011903471469922673 -6.564493658443174e-5 -0.00012102479657114467 7.493788591164083e-5 -0.00018714740424839122 0.00010145972639366166 -0.00019731152413369294 -6.291650341081723e-5 9.215803304032091e-5 0.0001025999489873348 -0.00021959200519558832 0.00012352368950887212 0.0001221854662806481; -6.609221077242285e-5 0.00011682717536159226 -0.0001606078460551567 5.756526093163972e-5 7.986200012082813e-5 6.270353128483992e-5 8.826511998476726e-5 -0.00017981624318936265 -0.00012241900753879757 -3.6292715094267435e-5 2.370275046841336e-5 -0.00014455532277652427 -6.779908494509277e-5 6.330005794579699e-5 -0.00012561254175098274 -9.716452973380152e-5 2.591568773184837e-6 -0.00010348689306711331 9.73338953829612e-5 8.626886642586959e-5 -7.586733613579979e-5 0.00013393696076518301 0.00018110759060645907 7.10015572170035e-5 -4.687235036699707e-5 0.00010065534820596217 2.9690376095730627e-5 -4.118095080202814e-5 -0.00019219136652254932 6.843475216582729e-5 0.00017971763531959267 0.00015173853694277093; -6.961740171074816e-6 5.254721439302084e-5 0.00017470025089310958 -0.00022974530544856068 1.1096346708971518e-5 -5.521015443050026e-5 6.511973809426558e-5 -0.0001683451144050664 -0.00014626374719920878 -4.422325470622914e-5 7.478866358031712e-5 6.506378598019088e-5 5.250515935800271e-5 0.0001015079953798264 -0.00020387345536403975 0.00018378677235191557 -1.837524065890212e-5 -0.00015654968116690873 7.259552260834437e-5 5.417117971345533e-7 -0.00014395897120961744 -0.00020609222953465734 -0.00014927065309861071 -1.3111956620967332e-5 6.085477910561893e-5 -5.005826348464553e-5 4.984803420698075e-5 5.476095550283724e-6 -2.904323152292385e-5 -0.00010409546190897435 -0.00011546606653396544 -6.79040047249353e-5; 6.17931368371293e-5 8.554510029095478e-5 -0.0001465335294877405 8.744234989460151e-5 -0.0002312130887101584 0.00015660223331831098 -7.902196332982845e-5 -7.028965731553683e-5 0.00011407220196739317 0.00012975138858487375 
9.707148316843664e-5 5.588694008945517e-5 7.366551332704447e-5 -3.875627183408567e-5 -7.486907227833602e-5 -7.14359771616953e-5 0.00014767310537460959 -0.0001312324525595868 -3.7444449418056985e-5 -6.455703217473027e-5 -7.849198965074204e-6 3.053755833131554e-5 -0.0001261794455156894 0.000197300054850689 6.3031129404110225e-6 8.545522766250768e-5 -3.779262945574776e-5 6.063245788862008e-5 2.3578181270908125e-5 1.1164519009222911e-5 5.779608588379892e-5 9.647048179355148e-5; -3.5773981381329525e-5 -5.762972251079225e-6 0.0002122041873912051 8.435209338075516e-5 -0.0001638603757788714 3.242137927317072e-5 -1.278543140925752e-5 -1.6890173811238008e-5 0.00015745055452442982 6.290380724694124e-5 5.349711818725188e-5 -4.855540146364895e-5 -1.3482649597909114e-6 4.5966447486886386e-5 -6.919575071207125e-5 7.601709103199619e-5 -3.392709231831138e-5 -3.750844598527034e-5 -9.38124035206908e-5 -5.443063173571625e-5 -5.0161150739416476e-5 -6.858051028813146e-5 6.634138435346123e-5 9.729769544663806e-5 6.225829565108437e-6 0.0001019324004113397 0.0001323967694583065 5.025951394951538e-6 -0.000198994985567527 -0.00015829016456392453 -0.0001770429121285033 -1.3619231595485543e-5; -5.472536089213759e-5 0.0002881397752979505 0.00010316984240931475 -3.373476523405852e-5 -2.1653450526303656e-5 3.45464262544361e-5 5.437168998911347e-5 -0.00017824050758931037 -0.00020604082113562624 -0.00015564232911413595 0.0001006211617685632 7.945051890133707e-5 -0.00012328167816589623 5.179069318260877e-5 5.008772281927059e-6 -2.3069802987395098e-5 7.308260079740488e-5 -0.00013007630208007048 -3.46425901035607e-5 3.313634930678548e-5 -4.1856694802064936e-5 0.0001158525859544718 2.229691869148098e-5 4.196811402367313e-5 -7.47219734286217e-5 -9.725188159349377e-5 0.000125776868448936 3.3866928210824924e-5 -9.869960798381886e-5 -0.00014414170312660517 6.0758460099220154e-5 -9.763974626491135e-6; -1.293822758450611e-5 -7.139452252154601e-5 1.119252586071225e-5 -9.501197534697261e-5 
3.3098757128776524e-5 0.00011522748202699765 -0.0001602701337760165 -5.240656143721568e-5 -0.00011536001299822177 -0.00021496667947932014 -3.116552004205756e-5 -2.0339824615780708e-5 3.460808541425042e-5 -3.07500701382851e-5 -4.4235693778275646e-5 -3.7227487766332866e-5 -0.00019531022887528526 9.156195694962344e-5 -0.00010755287409845 -0.00011611974666445388 -1.4247296050568777e-5 -9.275575909846928e-5 -7.684760537082322e-5 6.928384784990006e-5 9.250211253679363e-5 -5.5225030948317525e-5 -1.7779249603311324e-5 -0.0001728723542982833 -9.738669511713674e-5 1.2559103495777928e-5 0.00012831793385705448 4.793882210657137e-5; 2.9047819085606098e-5 4.660128692754099e-5 -3.909844522590246e-5 -7.739054413272496e-5 -3.526801370193652e-5 -4.796212593086486e-5 -6.650281787954738e-6 -6.835172385203122e-5 -2.2527260765712916e-5 -6.37849578964987e-7 -0.0001479529556980593 -7.080633363679454e-5 0.00015666550175061025 7.617615196628088e-5 3.41490966513056e-5 2.704375478620696e-5 2.767993541616062e-5 0.00011986033209373747 4.465739843559045e-5 -6.278573269250284e-5 -9.072163370896549e-5 -1.6716684653004396e-5 -0.00012049443753662018 1.2998463909365465e-6 8.628283544217713e-5 1.705717736163682e-5 5.53295293194946e-5 3.661784725969253e-5 0.00019448424140696923 2.3861929059029597e-6 -0.00022672318587305424 0.00010548435679643869; -2.932964121724391e-5 -6.642512340192307e-5 0.00020103379615023597 0.0001861921813935026 0.00018992452937085721 3.1669537241337854e-5 -8.732612110976212e-6 -1.3207661647228006e-5 -9.610435478754448e-6 -3.467777747291309e-5 0.00020607986193055188 0.00010024768966390308 3.9159560804488496e-5 1.4555207092725545e-5 0.0002203872031196882 -0.00011287674311175642 2.808893162307091e-6 -5.9691261009780136e-5 4.428449068940655e-5 -5.338442093203545e-5 5.3602136579808606e-5 2.865955906884715e-5 -0.00014901430000178748 3.5730503833649237e-6 -1.605227003608198e-5 -2.2699619473191714e-5 7.671591123938905e-5 -8.721191584923775e-5 -8.492078955611106e-5 -5.110924353800616e-5 
9.900866684178459e-5 -1.8652988325701762e-5; 8.755796959613898e-5 9.172901416187613e-5 -0.00016132385128420334 0.00018059695655703057 -7.917523391787446e-5 -7.59246270941619e-5 -0.0001847550686985304 -4.635816942444555e-5 3.719391468204059e-5 -3.5710991437179474e-5 0.00014863363110217696 -3.847404354553621e-6 5.863093763347228e-5 -2.736969356609528e-5 -2.0845376934255748e-5 3.7119993298004014e-6 3.1767847463390576e-5 -0.00013535335376635682 0.00012428217967263872 0.00011146659896932057 -0.00015064265819509067 9.671730156279609e-5 -5.667149009658513e-5 0.00010459405682505556 0.00015361511548272775 2.3375657906774204e-5 -7.074225352388359e-5 -5.2809604607853294e-5 -1.458611798133257e-5 -2.367908775758027e-5 -0.00011707720325338207 8.046594917152219e-6; -2.652592787565231e-5 9.505572837197803e-6 -0.00015520823047775458 0.00012412828242592828 -0.00028683267568051387 8.234325092705304e-5 -0.00010570249117957272 -0.00014426182873360384 0.00017996991833858265 1.8002721720373482e-5 -0.00012336389603372655 0.00010729091615137133 2.7203719458313377e-5 0.00010018391354894657 -4.260231629526023e-5 7.85226608347537e-6 4.893512765481493e-5 2.8122505557959333e-5 -9.090234140628592e-6 -6.06253253288995e-5 -0.000258946316659447 0.00011298343621380882 0.00021754551004339308 -4.711045961523971e-5 4.276669811876699e-5 2.1821397115828932e-5 4.077220897959011e-5 -6.817837526639251e-6 1.2622649467786609e-5 1.1661340852852086e-5 8.064059680980133e-5 0.00011112517119106012; -2.3750851963891814e-5 -4.405777907725755e-7 -0.00013225642024180254 -3.260135772860951e-5 0.0002495632071223685 -4.553048338847464e-5 0.0001867261745211195 5.515241363911102e-5 0.0001943703246944073 -5.450374729108577e-5 -4.3308922526107825e-5 2.9265855538269397e-6 8.761425132989677e-5 2.829644299355331e-5 -2.1534678948133176e-5 3.0388471801518323e-5 -1.1064482263363126e-5 -0.00010849598757330652 -0.00026512208560587776 6.469294030113079e-5 -0.000128710118500648 0.0001967697317887597 9.423289165725532e-5 
-5.8710804209010866e-5 -4.247004463095071e-5 1.9135679482232733e-5 0.0001574037597457508 8.503365106308933e-5 8.334746243791501e-5 -0.00014200241994685394 8.896351218583242e-5 -1.4564795404931654e-5; -0.00012903239882513503 2.1366539960530683e-5 2.8617037180029184e-5 -4.105655431709209e-5 -0.00024778330304622257 2.1837607284341054e-5 -0.0001831194812148634 -3.55023320190102e-5 -8.01005288558793e-5 -0.00010859111977700521 9.48225362044078e-5 4.74302272044439e-6 3.875766663617339e-5 2.1130511533498058e-5 3.2485699029579278e-6 -0.00012271429192702626 9.405320555009706e-5 2.1634575325097208e-5 -1.1223415520421372e-5 0.00011143810218703651 0.00010822802244723627 -1.947069073872308e-5 -6.647597154423174e-6 7.023983301399873e-5 -0.00012501467323823275 -8.23188810209653e-5 -2.2233157204086007e-5 -7.657495631798646e-5 -9.719510023211304e-5 -2.56266601973524e-5 1.7298055474324085e-5 0.00017808724181300173; 0.00013929017735513617 4.753688824111235e-6 2.9994514658887492e-5 -3.092635084890906e-5 9.08191228854351e-5 -1.306623049736451e-5 -3.688932189385712e-5 -4.726904295516488e-5 -1.1557961773245863e-5 9.718525110375407e-5 -0.0001268590261707496 5.4053301607339214e-5 0.0002786523706499307 -6.114272251857041e-5 -0.00016060516496787635 0.00015772137124890848 8.731632409485214e-5 -3.199143826628224e-5 -0.00016640093086884772 4.356955981797843e-5 0.0001421709181494837 -8.629592626954048e-5 0.0001670942016398324 0.00012104239580089608 6.0160594909517746e-5 -2.253697163092025e-5 4.543455146354133e-5 0.00012480428416615925 6.309790148359447e-6 -5.112750874383392e-5 4.1788758277974604e-5 -3.922858592256621e-5; 8.472537577232969e-6 -0.000232088052523134 -5.7275796659331964e-5 -9.59472603581793e-5 0.00013287190048043762 4.305174259837437e-5 -5.630897832550311e-5 -0.00015826090422439262 0.00023594628836046293 -9.879902101337739e-5 -2.3681810868843136e-5 0.0003222494484206765 0.00024624040050451383 0.00010873916433677973 -0.00011158102317228816 -3.7786576665044276e-5 2.0444187232836016e-5 
0.00019315620700847906 6.0829562774356884e-5 0.0001753637839855709 6.005141729741136e-5 -6.549502033255737e-5 4.4384716223184934e-5 0.00010341250852701202 -0.00012719228872545016 2.344085404086517e-5 0.00010251690362832499 -7.574411612682894e-5 -3.910723573159624e-5 -0.0001378146775251255 1.2415901967812602e-5 -3.93208942269484e-5; -5.709912901721498e-5 -6.986536781396009e-5 1.6126494037688354e-5 0.00016278339872742331 -0.0002502721475050837 0.0001544455878590242 0.0001217586029358338 0.0001829031968642908 -1.966094740990536e-5 -0.0003218364875376815 -1.2375785709335873e-5 5.855508034955248e-5 -0.0002900994007089171 -2.8147240564346167e-5 -1.7738663055110385e-5 7.498724126862111e-5 -4.65176638411361e-5 -2.7969099656094326e-5 -0.00020282497306680684 -6.861750470334837e-5 -2.519277794845401e-5 2.8279162242782212e-5 5.422646345050309e-6 -0.00023962295407755545 -9.425049965129222e-5 0.00010382172305043957 -0.00011942995505733715 0.0001688351037134373 9.417402153275877e-5 -7.710264643746338e-5 1.1785617707291687e-5 2.99072196046053e-5; 4.5678561736923225e-5 3.7017811127783506e-5 0.00020802072483076722 -4.239751636321672e-5 -9.683133451305293e-5 9.433867468926328e-5 1.8094162232778782e-5 -6.651970964428036e-5 3.503495624481513e-5 8.943302761488337e-5 -9.859269105637218e-5 -4.9763324777968636e-5 0.0002178377378821305 1.848308763312819e-5 7.674551104156445e-5 -5.3872538255903216e-5 5.050420965728698e-5 5.420623228765489e-6 1.132660942842336e-5 6.0999178249245174e-5 -6.075648546373164e-5 2.0571711292918442e-5 7.319670637099874e-5 1.10819662675812e-5 -0.00012070960131512901 0.00011953422451353728 0.00011830503151016064 9.702185140652504e-5 3.823529347739533e-5 4.097311627429866e-5 6.090817057141175e-5 9.709749765353169e-6; 5.8916002695148165e-5 0.00011633409487823832 -5.118690203252184e-5 -8.85399112799755e-5 2.8551870878252686e-5 0.00015436717696994354 -2.001858408636522e-5 -8.381529730704798e-5 -6.527598995904551e-5 -0.00017187011505896671 7.674419089605438e-6 
9.39437686842201e-5 -0.00024115945232060056 -8.472612350906063e-5 -0.00019333061024970238 4.192990129504558e-5 -5.429347288185051e-6 -0.00023760524888947554 -1.977134340868327e-5 0.00019113049531861538 -4.7867104489881057e-5 -0.00013867830236866363 -0.00016649730057414388 -6.34768318795732e-6 -4.4760132345414946e-5 0.00013311311524867958 -2.500712957588178e-6 8.908614313804165e-6 -0.0002296407091986053 -0.00012571599470409187 0.00011525053016648348 -0.00016427418993277058; 4.8176118604540584e-5 3.405836704296362e-5 -0.0001544400996983278 7.24344377766136e-5 -8.201326037414401e-5 3.271205294986862e-5 0.00021853269080274754 -8.38922618760918e-5 -0.0001143444655219399 5.358889290425341e-6 9.103583814684766e-5 0.0001297332806523697 0.00021806519597412707 -6.742988390505467e-5 -6.041532598210275e-5 -1.070710872054065e-7 -0.0001639528030984317 0.00015871589365524773 0.00021568773316794904 0.0001231951778998392 -2.054879352107929e-5 1.2383729505747177e-5 -8.084973288823371e-5 -1.574293438642917e-5 -8.08466842619963e-5 0.0001564877626031414 3.1262771295296884e-7 -0.00015950042635558246 8.763984313809978e-5 -0.00019804452077842166 -1.9015748541204163e-6 5.314748576601824e-5; 4.995864781768096e-5 0.00016001410966373387 0.00017502077224298951 -0.00017492310309216749 0.0001392370015159997 0.00023286184130802474 -7.691863885748159e-5 1.1301625985202985e-5 -4.005720544689293e-6 2.9114871889291294e-5 -0.00017888234453176607 0.00011802667557595295 -5.4918162110422544e-5 -5.6914328357829795e-5 -6.137143811111083e-5 -1.415955813763074e-5 -4.88006788295902e-5 -8.460167752216847e-6 4.302066571063316e-5 0.00022925829233236462 8.236406409514831e-5 -3.3309128245037076e-5 3.5461109458539236e-5 2.840736686585358e-5 -6.067848682187279e-5 -1.7952492104539132e-5 -8.8262639260347e-6 0.00011801537601377798 -2.972650128549184e-5 -0.00019221453277760005 -0.00015104240612026263 0.00010837606367521792; -0.00016014561976217297 -0.0001488702283185404 -0.00019473315029781873 9.373314182172957e-6 
5.579597649020511e-5 8.380410630795134e-6 7.78598304660562e-5 -8.77852317588833e-5 -4.1336780728582314e-5 -0.00012034474918064291 0.00010207649478079189 -5.323337888139863e-5 -0.00016944959049024251 -7.2702110318488415e-6 2.2047498472690554e-5 8.674618305494444e-5 -1.7606721545146813e-5 -9.726099512218131e-5 -4.614198684603539e-5 7.099877705886395e-5 -2.5543961843052914e-5 -6.475389369907423e-5 0.0001230831812191326 0.0001520663762986231 0.00019220761617734467 4.653469016747918e-5 7.259859279119776e-5 0.00011472841009353654 0.0001886122017222959 -3.6183190846625737e-5 -0.00011720764001975402 6.400022988063651e-5; -4.600372906812294e-5 -5.406001858450289e-5 5.828050557151732e-8 1.757142340908717e-7 4.4173569193428344e-5 3.4377274431378473e-5 -4.1233400253479594e-5 -0.00013112840227355048 -8.18932994947662e-5 -0.00016251627224066857 -0.0002202411151587825 8.438757956222858e-5 -4.6713225884971594e-5 -0.0001782245117569398 0.00015178859482016722 0.00013633898649326913 -0.0002096954877119549 0.00010633752114567731 -0.00020339587629810617 5.900967605025024e-5 6.541511990558004e-5 -5.375645835688204e-5 -2.5578751224254536e-7 -3.1424921900411854e-5 4.103666739734252e-5 7.868140980331003e-5 0.0001621382225207926 6.339155240179292e-5 -9.592482368266588e-5 -3.172882773601029e-5 -3.0708894911602394e-5 3.900113341442239e-6; 4.9992998906227664e-5 9.589350848254464e-5 -0.00011953679895232857 -4.521174232214179e-5 2.4237987968072236e-5 9.083927180180783e-5 6.43673225611531e-5 -2.792769773369373e-5 -0.00015949751504667966 0.00018012393011740947 4.0121099668612804e-5 1.735624173742545e-5 -7.649423459516392e-5 9.592899332782606e-5 -8.150715111260538e-5 -1.4206263035893138e-5 8.596276493771425e-5 5.7812281379294954e-5 1.828081586135404e-5 -5.8001416978375916e-5 0.00012248642779449435 -5.415746672516061e-5 -9.615162379778385e-5 2.067932073695477e-5 -9.021871337172786e-5 -0.00010036737912306078 4.4650419663222184e-5 -1.062721311901395e-5 -9.99254883892792e-5 -5.459241983335654e-5 
-0.00018442836553295806 3.135695406381035e-5; 0.00011973215971192287 0.00019169491663159567 0.00014207403530416842 -4.1494384438928045e-5 -3.941637094431617e-5 -6.389880780080478e-5 -0.0001138124591107139 0.00014392605755528373 7.669652934236059e-5 1.367211900468308e-5 8.879499166322424e-5 0.00016315798513641513 2.7549320131576162e-5 -9.379051532778985e-5 -0.00017071884410871447 -5.1207689628434415e-5 0.00011317309492550662 -8.97850275397e-6 0.00012603072338787758 -0.00014454144746014962 3.391628872304091e-5 -3.1174599078448846e-5 -0.0001732090114983376 8.20737166777616e-5 -4.4215709001955415e-5 -4.485320294225744e-5 -0.00011637105688169688 -9.324633918186723e-5 -6.052233054834545e-5 8.996869093788514e-5 -4.704231667165699e-5 3.364444075667871e-5; -4.702956689939289e-5 -7.131987118112049e-5 -0.00014585608337723636 0.00011562906397090331 -7.834861358346784e-5 0.00017240534757080229 7.933392439088733e-5 -2.31402388227716e-5 -1.8958911509895653e-5 2.8677157702006588e-5 -0.00018293303915479446 -4.5370513958141936e-5 -7.373295707579093e-5 1.2730193106884424e-5 -6.669269683254729e-5 -0.00019333975401891295 -4.762125868650609e-5 7.675985793399619e-5 -0.00015772549855241666 -0.0001340184932406782 -2.21325550729868e-5 -0.0001412683301670128 7.026148918600014e-5 -1.3873782022661796e-5 2.2530366863862588e-5 -0.0001045994028164737 -5.766660947769866e-5 -0.0001072619376419878 -0.00016180357093175396 -1.408833818925135e-5 4.9776281543789014e-5 -3.7327361183189365e-5; -5.7348748676026184e-5 -0.00013489543562873505 0.00012394787699331592 9.500309169407057e-5 0.00018210971392573414 3.317419031400401e-5 -6.090413808825109e-5 -0.0001046934667171577 0.0001272850095049659 -0.00010706318790446148 -1.00047664161268e-5 -0.00012390172560817044 9.85925107600466e-5 3.917338471392561e-5 0.0001883471304538905 5.831326415527994e-5 7.836032818186337e-5 -1.74545712140568e-5 0.00010988363282404371 5.1603888998491284e-5 -4.010465637083877e-5 -5.604761827971979e-5 1.3725287310352972e-5 
1.4916184220331768e-5 -8.015329149903184e-5 1.4174976961204035e-5 -4.69441329257242e-5 -5.19069726648877e-5 -1.8631515174072985e-6 3.574189030795853e-5 -5.400862574578618e-5 0.00018598469977613646; 0.0001736524390291829 0.0001672975885451205 9.292057169883848e-6 -6.033539838950074e-5 0.00016490384214580398 0.00010777525667822122 -1.403569703666768e-5 -6.521978147770882e-5 0.00013607277278810274 -9.796695936315778e-6 0.00014712931797861443 -4.164169885893967e-5 -0.00019780737429369094 -3.119667607802646e-5 -1.5232804885959636e-5 1.3760122696338355e-5 -2.7620048145538337e-5 0.00018195478688016618 3.427545007457314e-5 -9.311553376005743e-5 -2.2506568798919455e-5 -6.0711805504756594e-5 1.7036696503279243e-5 4.8982139473308465e-5 0.00010302804358092913 -0.00015498015645403312 -0.00011197066132422275 0.00012610220792725992 -3.818471313548102e-5 -6.0520433267589615e-5 -0.00011016012658342372 -5.751348371246283e-5; -0.00011108716649007675 1.3867308323346752e-5 2.348485393189312e-6 -1.528399709537763e-6 -1.1993939781430327e-5 -0.00023308495224457745 0.00015875972235105643 1.0843143852500612e-5 -1.0985019118855256e-5 0.0001808532695099348 -1.7799192621643852e-5 -0.00010580913590504467 5.2897012890274106e-5 0.00018547111608083718 -8.287793684987719e-5 1.9855361413167665e-5 8.039152521402976e-5 -8.136630572382843e-5 -2.6690711140442988e-5 4.409581447132306e-5 -9.151669587711397e-5 -5.4009552743560074e-5 0.00014298832752272024 6.461426596385789e-5 -6.386219252137821e-5 -0.000133850742033725 -1.8112700902317866e-5 -0.00017004675396722416 -8.768017801460409e-5 3.423720477028408e-5 6.164946603118747e-5 0.00011232305473411222; -0.00018050224908504691 -0.0001832259019545144 0.00018206336089066817 0.00027969365900504993 -2.2705933323531104e-5 -0.00010263485611060918 6.90671861696171e-5 4.85032053723179e-5 -5.733491498971361e-5 0.00013591824817126984 -0.00010053747401683742 -5.645115898799235e-5 8.918069787911799e-5 0.00010246547598514548 5.393368182127725e-5 -0.0001728549120423214 
9.20931036410084e-5 7.558496717253362e-5 8.605015890795623e-5 9.400639365839735e-6 -4.046634520574208e-6 -2.231771005311536e-5 7.924084483535579e-5 9.33239410108125e-5 -2.584267322957646e-5 -0.0001041266311483303 6.0019931800253345e-6 -1.5171198091347263e-5 -8.020251126138735e-5 -5.6830374622891386e-5 1.7113179902727196e-5 4.9579868838247795e-5; 8.557334802262839e-5 0.00014599907006285968 -7.766190976453888e-5 9.072790560694465e-5 7.175336502641788e-5 -2.579414360425316e-5 -0.00010822005135324192 -5.343151850846792e-7 4.2756534802349285e-5 -4.7035335015982136e-7 0.00018349742020162967 0.00011712881239746093 -1.8741717959892536e-5 -7.756174893202009e-5 -0.00010918924528723763 0.00015510009978115806 -0.0001386389383891581 3.434670286751415e-5 -0.00015270791920683985 3.958159066187433e-5 -1.0074951952810762e-5 2.8197900227130397e-5 0.00013841851175911744 -8.800182406175639e-5 8.057003576922955e-5 -5.472419508933003e-5 -5.9149494355914344e-5 6.462896560933756e-5 -4.623500064681121e-5 -4.234488808051918e-5 -6.042983734132364e-5 8.713436107299757e-5; 6.556408045651321e-6 0.00013152420923017872 5.767233722747714e-5 0.00013512292696998848 -0.00012858449428437084 0.0002558958052601441 0.00016779709893182282 -1.6801832878047814e-5 5.6190595735291145e-5 -0.00013812602502858006 -1.632580517015512e-5 4.654475322480449e-5 2.8755666228454906e-5 -5.176112483001363e-5 2.472003536807617e-6 2.445398730072596e-5 -0.00017299123521058156 3.800041793750089e-5 -5.572550673365539e-5 0.00021582445823026687 4.5844435028477404e-5 2.949593670803161e-5 2.5255841485549093e-5 4.418075094020857e-5 -2.679546252454872e-5 0.00012738129351658292 -0.0002187027008561605 2.784109836523992e-5 -0.00014748731575070216 5.718300725004323e-5 2.779064141817609e-5 0.00012125166302450059; 5.4204743858152413e-5 -5.378792579661984e-5 -0.0002239076110048138 0.00019462553805411683 -8.052034340593458e-5 4.609131988460017e-5 8.603887013365825e-5 -4.3862882964587263e-5 7.103543807958796e-5 -1.6769543296799552e-5 
7.58702358568446e-6 0.00021521574140266115 -6.979154372615004e-5 -9.655631388099256e-5 -0.0001471331032438097 6.60027719877292e-7 -4.4686179396503084e-5 1.019422066079797e-5 7.264187493334157e-5 7.973050574279105e-5 0.00021705602025818793 -0.0001424691998612038 7.436873616157956e-5 -0.00012880768722183384 1.859573617608918e-5 -1.1461683104551932e-5 -0.00011955288741532118 -0.00015491510371006483 -0.0001391532112823729 -0.00011496331521569555 1.1449944534122753e-5 0.00011345102589318038], bias = [-9.559509660916014e-10, 5.072709417599869e-10, -3.6550478873940395e-9, 1.797588244921871e-9, -4.1154691094828903e-10, -8.063899652401916e-10, -4.623572728584265e-9, 1.0627096118871139e-10, 2.63793767067262e-9, 1.9713672630233716e-11, 4.455279602379686e-10, 1.9063883106573482e-9, -2.225413542312151e-9, 3.1010467443216156e-9, 2.0878205185708496e-9, -2.9834066572843145e-9, 3.7159860141382096e-9, -4.7348063187579195e-9, 1.4429922959170848e-9, 1.4971951772114186e-9, -2.788422676085933e-12, -2.229200125001887e-9, -8.989402066560683e-10, 4.127486975912763e-10, -5.045562968074985e-9, 1.895379960708922e-9, 1.000357087453982e-9, -5.377410414791939e-10, 1.0752876470288586e-9, 1.1161124369653652e-9, 2.7624834164453142e-9, -1.1912992174765098e-9]), layer_4 = (weight = [-0.0006918974391900458 -0.0008177174401687461 -0.000636989189519575 -0.0006433805649082473 -0.0006557724903089806 -0.0005488227478956973 -0.0006688147615888573 -0.000541044821401495 -0.0006205778354392366 -0.0006261575477209694 -0.000827480365171295 -0.000748876055177267 -0.000684392769353398 -0.0006667961182619614 -0.0006360111108002784 -0.0005757528939895349 -0.0006863834467759504 -0.0007198284271768007 -0.0006820316238142925 -0.0006668108889852215 -0.000945301697855303 -0.0007181306157584802 -0.0006466967062043323 -0.0007332734932615828 -0.0007103647343633755 -0.0007949340662056953 -0.0007175189924100777 -0.0006299760802530885 -0.0008292708971894726 -0.0008053291746736288 -0.0005216583739639397 -0.0008075736359676379; 
0.0003069327687931586 0.000232522620832939 0.0001969042436877336 0.0001838048654036996 0.0002880846096700942 0.00018754115137455986 0.00023505441439156463 0.0001360165509565721 0.0003382471312032262 0.000187547058364157 0.00020767313599779288 0.00021209252875137899 0.00028276958313592327 0.0003367632963388309 0.00024534759541425635 0.00012951422962125442 0.00032732570858978223 0.0002799493101120675 0.00021280058184297833 0.0002458196050175752 0.00021507182071169536 0.0004213790047506592 0.00032685192596003794 0.00017120540914065442 0.0002805192654931227 0.00029164755516949194 7.495049401539739e-5 9.693780271137333e-5 0.0001199852286484975 0.00017995164067108877 0.00032646466089365317 0.00020616943030211127], bias = [-0.0006856673576735526, 0.00021150311243176145]))Visualizing the Results
Let us now plot the loss over time
# Loss recorded by `callback`, plotted against the iteration index.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)
    fig
end
# Finally let us visualize the results
# Re-integrate with the optimized parameters `res.u` and compute the trained waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params
) |> first
# Overlay the data, the untrained prediction and the trained prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    fig
end
# Appendix
# Print Julia build/platform information. When MLDataDevices is loaded, also print CUDA
# and/or AMDGPU details for each backend that is loaded and reported functional.
# Plain `if` blocks (not `&&`) keep the block's final value `nothing` when no backend is
# present, so nothing extra is displayed.
using InteractiveUtils
versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println(); CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println(); AMDGPU.versioninfo()
    end
end
# Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate

This page was generated using Literate.jl.