Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Precompiling OrdinaryDiffEqLowOrderRK...
4000.3 ms ✓ OrdinaryDiffEqLowOrderRK
1 dependency successfully precompiled in 4 seconds. 98 already precompiled.
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, they are not particularly relevant.
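We need a very crude two-body path. We treat the one-body motion as the Newtonian relative position vector of the two bodies, `r = r₁ − r₂`, and split it between the bodies in proportion to their masses so that their centre of mass stays at the origin.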
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the positions of two bodies with masses
`m₁` and `m₂`. Returns `(r₁, r₂)` with `r₁ = m₂ / (m₁ + m₂) .* path` and
`r₂ = -m₁ / (m₁ + m₂) .* path`, so `m₁ r₁ + m₂ r₂ = 0` and `r₁ - r₂ = path` (up to rounding).
"""
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    weight₁, weight₂ = m₂ / M, -m₁ / M
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)
Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution expressed in orbital variables into Cartesian orbit coordinates.

- If `soln` has 2 rows, they are `χ` and `ϕ`, and `model_params` must be a length-3
  collection `(p, M, e)`; `p` and `e` are then constant along the orbit (`M` is unused).
- If `soln` has 4 rows, they are `χ`, `ϕ`, `p` and `e`, each sampled along the orbit.

Returns a `2 × n` matrix whose rows are `x` and `y`, computed from
`r = p / (1 + e cos χ)`, `x = r cos ϕ`, `y = r sin ϕ`.

Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if it has 2 rows and
`model_params` is missing or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # `length(nothing)` would throw a MethodError, so test for `nothing` explicitly
        # to give the caller the intended assertion message instead.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, M, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)
`d_dt` approximates the first time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled series `v` (step `dt`).
Interior points use the central difference `(v[i+1] - v[i-1]) / (2 dt)`; the first and last
points use second-order one-sided stencils. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    forward_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    central = (v[3:end] .- v[1:(end - 2)]) / 2
    backward_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(forward_edge, central, backward_edge) / dt
end
d_dt (generic function with 1 method)
`d2_dt2` approximates the second time derivative using the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled series `v` (step `dt`).
Interior points use `(v[i-1] - 2v[i] + v[i+1]) / dt^2`; the first and last points use the
four-point one-sided stencil `2, -5, 4, -1`. Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end
d2_dt2 (generic function with 1 method)
Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free mass quadrupole tensor along a planar orbit.
`orbit` is a matrix whose first two rows are `x` and `y`; `component` is an index pair.
`(1, 1)` gives `mass * (x² - (x² + y²)/3)`, `(2, 2)` gives `mass * (y² - (x² + y²)/3)`,
and any other pair gives the off-diagonal `mass * x * y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    sq_x = xs .^ 2
    sq_y = ys .^ 2
    r_sq = sq_x .+ sq_y
    i, j = component[1], component[2]
    tensor = if i == 1 && j == 1
        sq_x .- r_sq ./ 3
    elseif i == 2 && j == 2
        sq_y .- r_sq ./ 3
    else
        xs .* ys
    end
    return mass .* tensor
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (via `d2_dt2` with step `dt`) of the `component` entry of
the trace-free quadrupole tensor returned by `orbit2tensor(orbit, component, mass)`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(h11, h12, h22)` quadrupole components for a body of mass `mass` on `orbit`,
each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_of(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component_of((1, 1)), component_of((1, 2)), component_of((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` for a single unit-mass body following `orbit`,
sampled every `dt`: `h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    amplitude = √(T(π) / 5)
    plus = amplitude * (h11 - h22)
    cross = -amplitude * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the `(h11, h12, h22)` quadrupole components of two bodies, each computed with
`h_22_quadrupole` from its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q1 = h_22_quadrupole(dt, orbit1, mass1)
    q2 = h_22_quadrupole(dt, orbit2, mass2)
    return q1[1] + q2[1], q1[2] + q2[2], q1[3] + q2[3]
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` of a binary from the orbit of body 1 (mass
`mass1`) and body 2 (mass `mass2`), both sampled every `dt`. With `h11`, `h12`, `h22` the
summed quadrupole components of both bodies, returns
`(√(π/5) (h11 - h22), -√(π/5) · 2 h12)`.

The masses must sum to one. The check uses an absolute tolerance of `1e-12`, widened to a
few ulps for lower-precision number types. In `Float32`, `mass1 + mass2` is generally not
within `1e-12` of one even when the masses come from a valid mass ratio.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    total_mass = mass1 + mass2
    # 1e-12 is below Float32 resolution near 1 (eps ≈ 1.2e-7), so scale with the type;
    # for Float64 this is still exactly 1e-12.
    mass_tol = max(1.0e-12, 10 * eps(float(total_mass)))
    @assert abs(total_mass - 1.0) < mass_tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` from an ODE solution `soln` sampled every `dt`.
The solution is first converted with `soln2orbit(soln, model_params)`. `mass_ratio` must
lie in `[0, 1]`:

- `0` selects the single-body (test-particle) waveform.
- Otherwise, the orbit is split into two bodies with `m₂ = 1 / (1 + mass_ratio)` and
  `m₁ = mass_ratio * m₂`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # Zero mass ratio: test-particle limit, no two-body split needed.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
compute_waveform (generic function with 2 methods)
Simulating the True Model
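`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p-2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}}$$

where $p$ is the semi-latus rectum, $M$ the mass, and $e$ the eccentricity, so that the orbital radius is $r = p / (1 + e\cos\chi)$.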
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the orbit ODE for a point-like particle in a Schwarzschild background.
`u = [χ, ϕ]`; `p`, `M`, `e` are the semi-latus rectum, mass and eccentricity. `t` is
unused, since the system does not depend explicitly on time. Returns `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)
    numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    dχ = numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * denom)
    dϕ = numer / (M * (p^(3 / 2)) * denom)
    return [dχ, dϕ]
end
# Simulation settings. These globals are read inside `loss` (directly or through `solve`),
# so they are declared `const`: non-const globals are untyped, which makes every access
# from a function type-unstable and slow.
const mass_ratio = 0.0 # test particle
const u0 = Float64[π, 0.0] # initial conditions
const datasize = 250
const tspan = (0.0f0, 6.0f4) # timespace for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
# Spacing of the saved samples, used for the finite-difference derivatives of the waveform.
const dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step size.
const dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Ground-truth trajectory from the relativistic model, integrated with fixed-step RK4 and
# saved at `tsteps`. `waveform` (the h₊ polarization) is the training target read inside
# `loss`, so it is declared `const` for type-stable access from that hot path.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
# Plot the ground-truth waveform sampled at `tsteps` as a line plus scatter markers.
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
# `[[l, s]]` groups the line and the markers under the single "Waveform Data" entry.
axislegend(ax, [[l, s]], ["Waveform Data"])
# The figure is the value of the block, so it is displayed.
fig
end
Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the angle χ, i.e. the first component of the state `u`) and has two outputs. We'll write a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
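We deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal distribution with a tiny standard deviation (`1.0e-4`), and biases are initialized to zero.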
# Network N(χ). The first layer applies `cos` element-wise to the input. It is followed by
# two hidden layers (32 units each, `cos` activations) and a linear layer with two outputs.
# Weights come from a truncated normal with std 1e-4 and biases start at zero, which keeps
# the untrained network's output small.
const nn = Chain(
Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# Initialize the parameters `ps` and the (empty) layer states `st`.
# NOTE(review): `Random.default_rng()` is not seeded here, so the initial parameters differ
# between runs; pass a seeded RNG (e.g. `Random.Xoshiro(0)`) for reproducibility.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[2.2991366f-5; -8.583498f-5; 6.5517443f-6; 5.3402124f-5; -1.08666145f-5; -0.00010128457; 5.68523f-5; -1.6638936f-5; 0.000119466524; 3.055531f-5; -5.5245313f-5; 3.9992938f-5; -4.8409245f-5; -0.000107557666; 5.3367596f-5; -1.5480422f-5; 6.2902865f-5; 5.1292627f-5; -4.8426686f-5; -0.00015003336; 0.00024799266; 0.00025323709; 6.2702624f-5; 9.4228344f-5; -7.783826f-5; 7.317188f-5; 3.9030347f-5; -2.020953f-5; 4.308021f-5; -8.3404164f-5; 3.3048757f-8; -9.1582035f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-5.8219568f-5 5.6704623f-5 0.00013909888 5.2980176f-6 3.6338897f-5 -6.181005f-5 -2.1242795f-5 0.00025756564 -9.458619f-5 -4.7086924f-6 -4.391294f-5 5.24156f-5 3.601997f-5 7.158554f-5 -9.0418696f-5 -4.767726f-5 -2.330328f-5 -3.689927f-6 -8.1392456f-5 -6.6408385f-5 -3.9653798f-5 -8.986469f-5 0.00013489745 -2.5962598f-5 -0.00010218255 -7.5238916f-7 -4.230172f-5 6.784941f-5 9.534376f-5 0.00011720901 6.885369f-5 2.3497922f-5; -0.00011960478 -1.3658739f-5 9.387308f-5 0.00011599899 3.7525122f-5 -4.3441894f-5 9.467006f-5 -0.000110901885 0.00012571647 2.0235473f-5 -9.739728f-5 5.748771f-5 4.208505f-5 3.8626f-6 -7.067662f-5 8.336922f-5 -2.801652f-6 -7.1520794f-6 0.000114836126 -2.4606472f-5 3.627661f-5 -3.3389548f-5 0.00010367365 -0.00012335548 9.1391985f-6 7.777666f-5 -3.4138793f-5 6.1247396f-5 -3.9670067f-5 -0.00015745441 -1.601005f-5 -3.4259418f-5; 3.3351796f-5 1.9055076f-5 -9.58331f-6 5.784427f-5 -2.6395472f-5 5.249458f-5 -9.734788f-5 1.2136419f-5 9.967788f-6 5.411607f-5 -5.5019864f-5 3.6620233f-5 -8.708673f-5 -0.00034498586 -4.5442306f-5 2.502808f-5 2.9541503f-5 9.798317f-5 7.271462f-5 0.00015662667 -0.0001535425 8.072804f-5 9.215586f-5 -0.000103857776 -1.5753456f-5 1.4193621f-5 -0.00014303405 -7.267148f-5 5.6753757f-5 -5.0009206f-5 5.080755f-5 6.350586f-5; 
0.00020482356 1.0861745f-5 -9.671719f-6 -0.00015005218 0.00020555199 1.207372f-5 -9.4304036f-5 0.0001288902 4.7468267f-5 0.000116898824 5.4066335f-5 0.00013354643 -8.028522f-5 8.322578f-5 0.000103793514 0.0001100739 3.5063067f-6 3.9725797f-5 2.5333697f-5 -1.3375261f-5 -0.00026227834 -1.3072677f-5 -4.7417334f-6 -9.669438f-5 5.417318f-5 4.995471f-5 -1.4230704f-5 -2.6599038f-5 -5.8075286f-5 9.87433f-5 -6.853814f-5 -2.3976869f-5; -0.00017441036 -1.8112985f-5 1.829317f-5 8.0993865f-5 4.7696958f-5 -1.36927465f-5 -1.2364883f-5 0.00014692063 -5.1033396f-5 -7.022112f-5 5.4461256f-5 -0.00013015722 0.00015812117 -0.00011873402 -1.1324742f-5 -3.8534552f-5 0.00015574784 0.00010927489 -3.2545286f-5 -1.386833f-5 4.642336f-5 -0.00021665895 8.192001f-5 -7.7442186f-5 -5.897529f-5 2.9502691f-5 -0.00010380722 0.000114261835 -5.446264f-5 -5.8757487f-6 4.559499f-6 -0.00012538399; -0.00013008293 -6.9964417f-6 -0.00014034704 2.2409691f-5 -0.00011152868 -6.329138f-5 -3.182964f-5 -0.00011595624 -0.00011920466 -0.00011475798 4.716311f-5 8.147214f-5 0.00017930815 -0.00023311313 -6.124976f-5 4.5375136f-6 3.9529394f-5 8.34973f-5 -2.3085977f-6 -0.00019552102 0.000100885045 -0.00025820328 -5.865585f-5 -0.0001088247 0.000266793 -0.00015161952 -0.0002216725 -2.1376787f-5 -5.86515f-5 -7.548875f-5 -4.3899137f-5 1.986642f-6; 0.00014415577 -0.00018482882 -3.6343943f-5 -0.00018647846 0.00019376157 5.0148323f-5 -4.5797537f-5 3.939363f-5 -9.239469f-5 0.0001247317 -1.9093177f-5 -2.3900902f-5 -4.1360327f-5 4.251934f-5 -0.00018311525 -5.102015f-6 -0.00018249115 -0.00012076465 -0.00015128509 7.285819f-5 4.932682f-5 -0.00024349226 -5.0953866f-5 0.00010831132 -3.6324185f-5 -0.00021203268 -9.224854f-6 6.974685f-5 -0.00014340426 4.1426232f-5 -6.626458f-5 -0.00013871829; 0.00010054425 0.00014877101 -3.135679f-5 -9.73993f-6 -0.00017001801 -0.000115089344 0.000121018245 -4.598016f-5 -0.00021076006 5.5392222f-5 -0.000115251736 -1.9660432f-5 -9.079907f-5 3.5645521f-6 -5.306693f-5 0.00011611718 -0.00021768348 
2.6270727f-5 -4.2932243f-5 -9.7235534f-5 -0.00012999916 2.1690092f-5 0.0001426421 1.767561f-5 9.095635f-5 -1.12053285f-5 0.00019969138 -0.00012916494 0.00010231083 0.00023624561 6.777519f-5 5.451278f-5; 0.00012464798 0.00013274785 -2.9369932f-5 7.485238f-5 6.512639f-5 -6.680164f-5 -0.00010541664 -0.00016385716 -2.3183457f-5 -7.647101f-5 6.0999497f-5 -9.4676696f-5 5.011235f-5 0.00016150443 0.00014428736 2.9507793f-5 0.00017615866 2.6011585f-5 -3.875828f-5 8.952636f-6 0.0003164984 -9.15453f-5 -7.9515405f-5 -0.00010215666 -6.782238f-5 -0.00014803624 -0.00023025337 -5.9613398f-5 -9.727077f-5 2.4852026f-5 -5.6738518f-5 -0.00012008567; 0.00011566258 5.742309f-5 3.1661562f-5 -8.692327f-6 -8.942991f-5 -6.0008988f-5 2.5776266f-5 0.00019651066 0.00020050172 0.00012090197 -2.9083873f-5 -0.00013580617 -2.3126371f-5 -3.3684573f-6 -9.3274706f-5 0.00018135727 -0.00018096833 -8.093442f-5 8.8149085f-5 -0.0001567792 9.923852f-5 5.664722f-5 9.294452f-5 9.837874f-5 -1.7441722f-5 -0.00012757829 3.2079155f-5 0.0002445041 -3.8906725f-5 -3.3469463f-5 -0.0001000994 -0.00012877729; 3.718797f-6 -2.1687136f-5 -0.0001640504 -8.442801f-6 -4.359968f-5 0.00013304879 -0.00012214926 0.00024771196 0.00020027671 -7.677207f-5 -0.00028258158 0.00010621158 -7.226625f-5 1.4821367f-5 -3.553894f-5 0.00016259399 0.0001259039 -8.335862f-5 -0.0001924937 0.0002723619 -9.3368835f-6 7.999939f-5 3.764387f-5 6.5320295f-5 -3.140368f-5 -6.973392f-6 -1.6847422f-5 -7.8723286f-5 -4.4342538f-5 1.1305947f-5 0.00010358653 4.575263f-6; -5.29492f-6 -5.5287765f-5 8.941942f-5 0.00013675053 2.672507f-5 -0.00018442085 -7.833097f-5 -0.0002055405 1.6782034f-5 -6.404517f-5 0.0001202632 0.00016463811 5.1574076f-5 -0.00018645506 8.4127554f-5 1.5897875f-5 -0.00014694144 8.0897815f-5 0.00015901207 3.6558545f-6 -5.1098625f-5 -0.00019310505 -0.00010694818 2.2090368f-5 0.00023296797 5.040106f-7 6.7863264f-5 0.00010880415 0.00017420988 -1.9578902f-5 1.4893183f-5 -4.159076f-5; -5.064541f-6 -9.506196f-5 -9.319963f-5 -6.800778f-5 
5.2210005f-5 -6.045147f-5 0.00013161394 -0.00018538322 -8.858062f-5 3.5096982f-5 -6.947449f-5 9.7227974f-5 -2.6148322f-5 -0.00010634075 -1.5957932f-5 -4.8160382f-5 -4.043f-6 -0.00014464262 -8.099493f-5 -3.204081f-6 -4.7974216f-5 -0.00027614943 2.1574846f-5 8.015198f-5 9.001251f-5 0.0001041434 -8.008616f-5 -0.00011603037 7.1378495f-6 -2.148154f-5 5.7295703f-5 -5.8323483f-5; -4.3500542f-5 0.00011202484 0.00016668237 -5.210714f-5 -0.00019196539 -1.604436f-5 -9.099276f-6 4.1495576f-5 -0.00010006564 -9.684118f-6 -1.4334895f-5 -0.00014128766 0.00022088927 -2.2216791f-7 -0.00013532043 4.8932263f-5 4.7953563f-5 -2.3124203f-5 -0.00023696244 -1.0545481f-5 4.7527377f-5 -5.504552f-5 7.725878f-6 -3.2219598f-6 9.246543f-5 8.885531f-5 2.7889593f-5 -0.0001820069 -4.25137f-5 -0.00023068814 -0.00011189843 -4.129034f-5; -3.2926212f-6 0.00015579262 0.00014197339 -7.219877f-5 -7.603878f-5 -3.8576825f-5 0.0001601683 9.785706f-5 -0.00013651178 8.864133f-5 9.024924f-5 -5.0272392f-5 3.335516f-5 -6.152382f-6 0.00020088599 -0.00019177354 9.965781f-5 0.000112179034 2.2115095f-5 -6.048329f-5 2.5668487f-5 6.6029825f-5 0.0001863165 2.0091322f-6 0.00024631867 8.289086f-5 1.1424067f-5 -0.00011566954 7.8614885f-6 -2.2169652f-5 -0.00012379751 -4.7618807f-5; 3.0237066f-5 5.825039f-5 2.1840582f-5 -1.1841094f-5 -0.0003159574 -5.824842f-5 1.0917654f-5 -4.117181f-5 -4.2723237f-5 -3.153646f-6 7.405224f-5 5.7832818f-5 -4.3124084f-5 -4.462651f-5 0.0001260359 -1.9451794f-5 -2.9074645f-5 -4.8125352f-5 -1.9907864f-5 0.00011573688 -4.3740412f-5 4.018838f-5 -2.8958395f-5 -6.784826f-5 -4.3501328f-5 -0.00014964542 -2.384436f-5 -9.250166f-5 -6.029466f-5 7.941433f-5 1.7571285f-5 4.6672308f-6; 0.00016627018 -6.961212f-5 2.1321137f-5 5.6782053f-5 -9.710608f-5 0.00013552446 5.9032376f-5 -3.2175823f-5 0.00012639814 -5.020873f-6 -0.000109676184 -3.8525555f-5 -0.00013106855 5.2797277f-5 2.03825f-5 -4.043624f-5 3.4771332f-5 0.00010529486 7.376408f-5 -0.00013571346 0.00016161801 -9.236775f-5 8.921748f-6 -8.1940234f-5 
6.508354f-6 2.785609f-5 -2.3311295f-5 -1.1418922f-5 -0.00010012313 0.00016973911 -7.106927f-5 0.00015622504; 6.397587f-5 -0.00018406246 0.00011752594 1.8860292f-5 5.8707632f-5 0.00015737921 -0.00014616319 4.132555f-5 0.0001056659 7.281754f-5 2.0111067f-5 -0.00017744533 -1.0770013f-5 -7.955331f-5 8.84615f-6 -5.9130645f-5 -7.57159f-5 -0.00015273236 7.44345f-6 0.00014308827 0.00016839925 -9.50598f-5 5.3912117f-5 -0.0001716648 0.00015608924 0.00010939007 -5.921574f-5 1.271382f-5 0.00012364633 -0.0001591116 -5.168662f-5 4.8619717f-5; 1.5812891f-5 -5.06485f-5 3.8489015f-5 7.7834426f-5 0.00013021064 2.1270242f-5 -1.657151f-5 -0.00013134192 -5.33457f-5 -3.6257923f-5 -0.00020107829 -5.208415f-6 -1.0505939f-5 -6.688823f-5 2.3405535f-5 8.6672146f-5 4.4903056f-5 -4.9080513f-6 0.00011603921 -7.6306496f-5 -0.00012519241 2.6933358f-5 -0.00015224263 -0.00010324471 0.00012737783 -2.6594396f-5 -0.00010656073 -6.220757f-5 -3.638078f-6 3.76934f-5 -6.229116f-5 -8.852289f-5; 0.00024204148 3.0505751f-5 7.594763f-5 -0.00016296154 -4.228169f-5 -0.00013318936 -1.5524804f-5 -1.0349198f-5 -0.00011035737 -0.00011095589 6.15931f-5 2.5214601f-6 1.7030345f-6 -0.00010229237 -3.0282456f-5 6.0268692f-5 2.1672666f-5 3.0093526f-5 4.7935566f-5 -0.00018624174 9.3035655f-5 -6.382835f-5 -3.243255f-5 -0.00024162146 0.00011185546 5.5010943f-5 -3.296486f-5 0.00018913408 7.7177036f-5 4.7840927f-6 5.6577606f-5 -4.5942816f-5; -4.915112f-6 0.00018245366 -9.505738f-5 -0.00012415995 -4.516775f-5 6.449376f-5 -0.00014523264 -0.00020073385 -2.1694057f-6 8.9043664f-5 -3.729595f-5 1.844521f-5 0.0001574745 -8.280848f-5 -0.00017453452 3.5641762f-5 0.00011641847 6.108968f-5 -4.963893f-6 -0.0002414814 0.00019049838 -8.489293f-5 3.5710105f-5 9.548226f-5 -1.4299045f-5 -4.6105542f-5 -0.00015862081 0.0001373321 5.457152f-5 -5.6522073f-5 -3.3330016f-5 1.992423f-5; -0.00014954919 9.905206f-5 -7.010505f-5 0.00012364965 -4.0216323f-6 -2.729231f-5 -0.0001382901 0.0001338079 4.063117f-5 8.265079f-5 5.0322673f-5 -4.9162758f-5 
-0.00012095617 -8.859562f-5 3.1220392f-5 -0.0001148445 2.0289395f-5 0.00010139919 0.00016565618 3.9682778f-5 8.0387645f-6 2.0663734f-5 0.00015239582 -0.00023597946 -3.7198093f-5 0.00013539482 8.610832f-5 -0.00010282331 6.4776764f-6 -3.6009708f-6 -1.928913f-5 0.00012705017; 7.059206f-5 0.00011196772 1.6811266f-5 -0.00013967267 -7.119975f-5 1.4272738f-5 1.7340139f-5 -6.3825275f-5 -0.00016999293 3.55723f-5 2.2983437f-5 0.000112556656 2.2287828f-5 -7.967637f-5 0.0001439371 2.8688523f-6 2.3933328f-5 -4.2329953f-5 -0.0002094586 -0.00014548535 6.32754f-5 8.032751f-5 -6.116973f-5 0.000106467516 -0.00014293104 -0.00011962647 -6.7312336f-5 -7.864014f-5 8.808249f-5 -5.000892f-5 0.00014251894 1.2858851f-5; 5.0789986f-5 2.3881916f-5 -1.1325148f-5 9.970558f-5 -0.00015339663 -3.8863778f-5 -0.00021929062 1.3534589f-5 8.505257f-5 -1.6347161f-6 -8.3298855f-6 -5.7275018f-5 -5.3919153f-5 5.8355774f-5 7.129384f-5 -0.00021498227 -4.5146066f-5 -7.958475f-5 -2.8198994f-5 6.169546f-5 -6.0515104f-5 -2.4709729f-5 0.000115471084 -0.00012174118 -7.391147f-5 1.2529966f-5 -5.243711f-5 -0.00010114203 1.7828377f-5 6.041175f-5 4.3625954f-5 9.129619f-5; -7.8915706f-5 -9.8405864f-5 0.00014986846 -0.00010302117 -6.8218185f-5 -0.00017460597 -6.139657f-5 2.3244713f-5 -3.0136f-6 6.156021f-5 0.00016575308 -3.7126494f-5 -0.00013699336 0.00013068513 -0.00019196208 7.111357f-5 3.8579838f-5 4.6295016f-5 9.733917f-5 7.993583f-5 1.3322711f-5 2.0724925f-5 -2.1665219f-5 0.000109700086 -2.2199141f-5 8.19288f-5 -7.731061f-5 7.4508964f-5 -3.0361532f-5 -0.0001927214 -0.00013098991 -0.00012409175; 7.561741f-5 -4.675435f-5 -9.1949834f-5 -3.06877f-5 -2.483484f-5 5.7443318f-5 0.00017178601 -6.405219f-5 -2.0827196f-5 -7.998652f-5 0.00018086698 -6.281998f-5 8.091007f-5 -1.1042142f-5 -3.8015514f-5 -5.756519f-5 0.00012053339 -0.00021711433 4.638347f-7 -5.420498f-5 -0.00011724154 -9.0514615f-5 -2.17583f-5 -8.9916175f-5 0.00019580382 4.9708287f-5 5.1555435f-6 0.0001041672 9.834567f-6 -0.000233346 9.646654f-5 -8.390887f-5; 
-9.240069f-5 0.00012608037 -0.00010943635 -0.000116679286 2.7861428f-5 0.00010098314 7.006868f-5 -0.00010649628 -0.00014637485 6.485012f-6 -0.00011195566 -7.3924105f-7 7.070484f-5 -0.00013806703 7.974296f-5 7.80397f-6 -5.4491047f-5 -0.00024627862 0.00018920106 4.771857f-5 -0.0002238211 -6.64457f-5 -2.15856f-5 4.219727f-5 -0.00012943766 6.643444f-5 -3.234695f-5 0.00013295692 -0.00012983057 2.974054f-5 -8.014962f-5 -4.4784752f-5; -1.4286879f-5 5.4815282f-5 6.232455f-5 9.705128f-5 -1.7352077f-5 -3.502548f-7 -3.4260636f-5 0.000112093454 4.2067382f-5 -0.00016220403 -0.00010702757 -0.00017336171 -1.4934125f-5 0.00012850718 0.00015867913 0.00016530057 3.6116213f-5 -1.4878923f-5 -0.00015608955 4.7779136f-5 -5.6406723f-5 -0.00015356796 -3.705806f-5 0.0001672441 -1.598048f-5 2.1240887f-5 7.908158f-6 3.6953683f-5 0.00012605585 4.835223f-5 -3.1092775f-5 -0.00012976603; 0.0001149705 -4.2501568f-5 3.083278f-5 8.42892f-5 -3.8461356f-5 -3.8497084f-5 -3.4017348f-6 -0.00015133599 -2.89122f-5 7.0419825f-5 8.728862f-5 6.20877f-6 -6.6108965f-5 -5.6569224f-6 -3.184992f-6 -0.00013108399 -5.4720615f-5 -3.931178f-5 8.382146f-5 5.7291058f-5 5.875278f-5 -0.00010029687 6.250409f-5 0.00010539442 1.6056528f-6 0.00018525262 -8.690197f-6 9.466893f-5 -5.1363102f-5 -5.47471f-5 -0.00014254628 7.9417605f-6; 1.2670863f-5 -6.1032326f-5 -3.5750918f-6 -9.536987f-5 -1.5925035f-6 1.2307509f-5 8.389418f-5 -0.00014150672 8.227648f-5 -0.000104219514 -5.6801695f-5 8.852177f-5 7.68314f-5 -0.000108089596 8.4097956f-5 -0.000112906324 2.7221135f-5 -2.6245347f-5 -2.5399655f-5 0.00012635207 4.360178f-5 3.0051103f-6 -0.00011316774 8.2238854f-5 -4.097577f-5 0.00014231713 2.4253937f-5 -2.6038026f-5 -0.00018511385 -7.4997843f-6 0.00017657303 2.4468374f-5; -8.7731505f-6 -3.527963f-5 -0.00015219612 -6.036556f-6 5.2334337f-5 -8.326705f-5 4.78126f-5 8.154865f-5 -3.5642992f-5 6.317863f-5 -6.164474f-5 -7.304108f-6 0.00010263827 0.00011686935 6.497645f-5 2.605956f-5 2.4777898f-6 0.00014118382 -0.00014548053 -0.00010210573 
-0.00018072438 5.657468f-5 2.0525453f-5 -0.00023116409 3.891694f-5 4.7703503f-5 -0.000109635526 -2.8641212f-5 0.00010068247 -0.00020162677 -2.0480791f-5 -4.1051855f-5; -0.000100948884 4.4854707f-5 -2.4547524f-5 -0.00023428276 0.00012745665 -1.4447015f-5 4.434548f-5 -0.00019636385 8.907514f-5 2.0591315f-5 -6.338693f-5 -6.6093453f-6 5.3547123f-5 -5.963573f-5 4.1942385f-5 9.420799f-5 -0.00015050011 0.000110864545 -4.132117f-5 -7.770551f-5 3.5643552f-5 -8.04523f-5 3.86794f-5 -0.00016456163 4.4176733f-5 -7.198278f-5 0.000120135424 4.1712497f-5 7.821983f-6 -8.1743405f-5 -0.00018307782 -0.00011983142], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[2.7501083f-5 8.1449296f-5 -2.4904184f-5 4.133003f-5 7.08009f-5 8.1315564f-5 9.78924f-5 -0.00010852508 4.623655f-5 -0.00012386133 -4.9571576f-5 0.00014878144 -1.8857882f-6 -0.0001391154 -5.9165643f-5 -0.00016309002 3.9849325f-5 -9.6225594f-5 -0.00010389071 5.293993f-6 -1.6439804f-5 -2.3427892f-5 0.0002053601 4.0283714f-5 -9.3371826f-5 -0.00011576774 -7.586191f-5 -0.00021418105 -1.8204253f-5 -1.21492585f-5 -6.4503365f-5 -0.00017605395; -1.8010292f-5 -0.0001370832 0.000104324405 -4.4561453f-5 -1.1193226f-5 -5.8805632f-5 0.00014144993 3.2744454f-5 0.000114442715 6.795314f-5 -3.4635596f-5 -5.6752247f-5 0.0001418459 6.700771f-5 7.073508f-5 0.00016414873 0.00010501503 -0.0001577505 0.00010373721 -7.4540854f-5 5.479496f-6 -0.0001799688 3.616827f-5 5.990478f-5 9.332182f-5 -6.1461f-5 -6.4399675f-5 -0.000140504 1.2711379f-5 -0.00012592625 -2.8141922f-5 -3.3294327f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Like most deep learning frameworks, Lux defaults to using `Float32`; however, in this case we need `Float64`.
# Convert the Float32 parameters to Float64 (`f64`) and flatten them into a ComponentArray,
# which the optimizer can treat as one parameter vector.
const params = ComponentArray(f64(ps))
# Bundle `nn` with its fixed state `st` so it can be called as `nn_model(x, ps)`.
# Parameters are supplied per call, hence `nothing` here.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
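Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, with each rate corrected by an output of the neural network $\mathcal{N}$:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr)$$

where $p$, $M$, and $e$ are the same fixed parameters used for the relativistic model.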
"""
    ODE_model(u, nn_params, t)

Neural-network-corrected Newtonian orbit ODE. `u = [χ, ϕ]`, `nn_params` are the network
parameters, and `t` is unused. Both rates share the Newtonian factor
`(1 + e cos χ)^2 / (M p^(3/2))`. Each rate is multiplied by `1 + Nᵢ(χ)`, where
`(N₁, N₂)` is the output of `nn_model` at `χ`. `p`, `M`, `e` come from the global
`ode_model_params`. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # Here `st` holds no state (every layer's entry is an empty NamedTuple), so it is safe
    # to ignore. In general, `st` should be used to carry the network's state.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
ODE_model (generic function with 1 method)
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Neural ODE problem, simulated with the untrained parameters. `prob_nn` is also read
# inside `loss`, so it is declared `const` for type-stable access from that hot path.
const prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
# Overlay the ground-truth waveform and the untrained neural-ODE waveform, each drawn as a
# line plus markers.
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
# One legend entry per series (line + markers grouped), placed at the left bottom.
axislegend(
ax,
[[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"];
position=:lb,
)
# The figure is the value of the block, so it is displayed.
fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean squared error between the h₊ waveform produced by the neural ODE with parameters `θ`
and the ground-truth `waveform`. The ODE is integrated with fixed-step RK4 (`dt`,
`adaptive=false`) and saved at `tsteps`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred, _ = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end
loss (generic function with 1 method)
Warm up the loss function by evaluating it once at the initial parameters.
# Evaluate the loss once at the initial (untrained) parameters; the first call also
# compiles the solve/waveform pipeline before training starts.
loss(params)
0.0006856768485697391
Now let us define a callback function to store the loss over time
# Loss history; `callback` appends one value per optimizer iteration.
const losses = Float64[]

"""
    callback(state, l)

Optimization callback: record the current loss `l` in `losses` and print the iteration
number (`state.iter`) and the loss. Always returns `false`; in Optimization.jl, returning
`true` would request early termination.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
callback (generic function with 1 method)
Training the Neural Network
Training uses the BFGS optimizer. This works well here because the Newtonian model provides a very good initial guess.
# Gradients of `loss` are computed with Zygote (reverse-mode AD).
adtype = Optimization.AutoZygote()
# The objective only depends on the parameters; the second (hyperparameter) argument
# supplied by Optimization.jl is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [2.2991365767651946e-5; -8.583498129151555e-5; 6.551744263552762e-6; 5.340212373987155e-5; -1.0866614502444959e-5; -0.000101284567790539; 5.6852299167066985e-5; -1.6638936358465054e-5; 0.00011946652375606651; 3.05553112411783e-5; -5.524531297849531e-5; 3.9992937672653254e-5; -4.8409245209838386e-5; -0.00010755766561471002; 5.336759568302574e-5; -1.548042200738777e-5; 6.290286546570847e-5; 5.129262717668331e-5; -4.842668568011623e-5; -0.00015003336011422554; 0.00024799266248010126; 0.00025323708541624966; 6.270262383616565e-5; 9.422834409628061e-5; -7.783826004018639e-5; 7.317188283186169e-5; 3.9030346670119536e-5; -2.020953070313731e-5; 4.3080210161824766e-5; -8.340416388808205e-5; 3.304875662023242e-8; -9.158203465620653e-5;;], bias = [1.6405519873136773e-17, -4.272748678538238e-17, 9.253988817853697e-18, 8.057406274053889e-17, -2.640246634192689e-18, -3.659242029018894e-17, 9.443448965725443e-18, -4.4156189522402673e-17, 2.3223727308933437e-16, 1.0303618320160742e-17, -2.0495562140531796e-17, -3.1282278352647296e-18, 7.973901789149076e-17, -1.8888176472972938e-16, 6.604401524608706e-17, 2.478097480626662e-18, 2.438743289775378e-17, 5.822291041694235e-17, -5.911711619538149e-17, -1.0999257678422786e-16, 2.6250745625823533e-18, 7.423954450148178e-16, 7.308403482961642e-17, 6.50315919233383e-17, 3.857281253808082e-17, 1.0651458237324152e-16, 7.68038989633434e-17, -4.37295324366588e-18, 4.072687902163377e-17, -1.2272453994680523e-16, 1.616095961885697e-20, -1.1869926619614855e-16]), layer_3 = (weight = [-5.8218608290521175e-5 5.670558242093957e-5 0.00013909983611632546 5.298977368179232e-6 3.633985691670739e-5 -6.180909328428933e-5 -2.124183517331876e-5 0.0002575665996406641 -9.458523051613972e-5 -4.707732631108603e-6 -4.391198014524447e-5 5.241655820249012e-5 3.602092986660419e-5 7.158650190070325e-5 -9.041773661771465e-5 -4.7676298680347974e-5 -2.3302320886037205e-5 -3.688967370512721e-6 
-8.139149646773593e-5 -6.640742573688095e-5 -3.965283828473288e-5 -8.986373065307456e-5 0.00013489840715159024 -2.5961637919287687e-5 -0.00010218159219981878 -7.514294278425262e-7 -4.230075934934995e-5 6.785037042539969e-5 9.534472267808748e-5 0.00011720997211490024 6.88546490270127e-5 2.3498881302323787e-5; -0.00011960423064276949 -1.3658187820446044e-5 9.387363177399717e-5 0.00011599954239233912 3.7525672884905765e-5 -4.344134351679671e-5 9.46706110672238e-5 -0.00011090133404916523 0.00012571702034353593 2.0236023990102962e-5 -9.739672743574206e-5 5.7488259405610886e-5 4.20855991952267e-5 3.863150756906374e-6 -7.067607141191556e-5 8.336976832385737e-5 -2.8011011844940182e-6 -7.15152857349316e-6 0.00011483667701856757 -2.4605921064841753e-5 3.627716221404667e-5 -3.338899689237228e-5 0.00010367420116114926 -0.00012335492858522165 9.139749342531368e-6 7.777721264307859e-5 -3.41382422656605e-5 6.124794679369404e-5 -3.966951619343298e-5 -0.00015745386401515944 -1.6009499015970075e-5 -3.425886672896847e-5; 3.335112248151063e-5 1.9054402371515736e-5 -9.583984167145562e-6 5.7843596521527664e-5 -2.6396145615033316e-5 5.249390502995524e-5 -9.734855578444407e-5 1.2135745207890459e-5 9.967114208671606e-6 5.411539584001966e-5 -5.5020537758716184e-5 3.661955903935701e-5 -8.708740372839845e-5 -0.00034498653715872465 -4.544298011041696e-5 2.5027405143011875e-5 2.9540828989232357e-5 9.798249354189738e-5 7.271394403487914e-5 0.00015662599886837397 -0.00015354317022882492 8.072736360901933e-5 9.215518636465579e-5 -0.00010385844964930174 -1.5754129919692943e-5 1.4192946735241693e-5 -0.00014303472292540143 -7.267215281189924e-5 5.675308327027125e-5 -5.000988042611058e-5 5.0806876031676766e-5 6.350518648815014e-5; 0.00020482562662162745 1.0863809745369289e-5 -9.669654776202928e-6 -0.00015005011106315297 0.0002055540518422141 1.2075785104550628e-5 -9.430197120583871e-5 0.0001288922632165713 4.747033211694969e-5 0.00011690088843987223 5.406839967416357e-5 0.0001335484977398519 
-8.028315811590487e-5 8.322784391291559e-5 0.00010379557910435604 0.0001100759674381925 3.5083713892002094e-6 3.972786192842545e-5 2.5335761290064927e-5 -1.337319632787637e-5 -0.00026227627118993046 -1.3070611988097857e-5 -4.7396687141533345e-6 -9.669231245699969e-5 5.4175243473747025e-5 4.995677421628101e-5 -1.4228639759508899e-5 -2.6596973714745333e-5 -5.8073221101054685e-5 9.874536602817903e-5 -6.853607361610085e-5 -2.3974804083100926e-5; -0.00017441138592318612 -1.8114006560670538e-5 1.8292149580637976e-5 8.099284346911224e-5 4.769593684698462e-5 -1.369376772260416e-5 -1.236590454852331e-5 0.00014691961147121166 -5.103441693705963e-5 -7.022213943771049e-5 5.446023456982442e-5 -0.00013015824046221504 0.00015812014786291019 -0.00011873503793964706 -1.132576366132493e-5 -3.853557322060678e-5 0.00015574681780065546 0.00010927387225904594 -3.2546306864851275e-5 -1.3869351131745583e-5 4.6422338674303036e-5 -0.0002166599759575978 8.191898556190677e-5 -7.744320722764993e-5 -5.89763102098627e-5 2.950166996100167e-5 -0.00010380824354712796 0.00011426081360780653 -5.4463659434140936e-5 -5.876769900231148e-6 4.558477631625757e-6 -0.00012538500854049526; -0.00013008842966753792 -7.001939646729555e-6 -0.00014035253576237595 2.2404193206444917e-5 -0.00011153417750263461 -6.329687772318654e-5 -3.183513859637332e-5 -0.00011596173595405961 -0.00011921015999046306 -0.00011476348032245142 4.7157612190336885e-5 8.146664453436077e-5 0.000179302648522412 -0.00023311862338101192 -6.12552585685751e-5 4.532015683636153e-6 3.952389574026566e-5 8.349180547286396e-5 -2.3140956350969077e-6 -0.0001955265152524001 0.00010087954700933627 -0.0002582087736831946 -5.8661346937307496e-5 -0.00010883019882254483 0.0002667875114106506 -0.00015162501682295342 -0.00022167799493997913 -2.1382285066082223e-5 -5.865699955263671e-5 -7.549424941186382e-5 -4.390463532998121e-5 1.9811440701569768e-6; 0.00014415126899549327 -0.00018483332276516933 -3.634844286424586e-5 -0.00018648295697929403 
0.000193757074538744 5.014382297233762e-5 -4.580203638468882e-5 3.9389131797370725e-5 -9.239918616073466e-5 0.00012472719792568706 -1.909767668692899e-5 -2.390540210470808e-5 -4.136482637548961e-5 4.251483953269474e-5 -0.00018311974743554434 -5.10651496441786e-6 -0.00018249564489522855 -0.0001207691489806036 -0.00015128958655694644 7.285368739689712e-5 4.932232006404679e-5 -0.00024349676421542141 -5.09583656646861e-5 0.00010830681678200957 -3.6328685001335746e-5 -0.00021203718241011427 -9.229353582475159e-6 6.974234967784531e-5 -0.0001434087640343814 4.142173240254234e-5 -6.626908323668428e-5 -0.00013872278539636364; 0.00010054430007363944 0.00014877106426644402 -3.1356737019845317e-5 -9.739876043053427e-6 -0.00017001796092011692 -0.00011508929034101518 0.00012101829910313049 -4.5980105789962263e-5 -0.00021076000572287228 5.53922755330549e-5 -0.00011525168243900602 -1.966037867753959e-5 -9.079901516309667e-5 3.5646058109393323e-6 -5.306687759224054e-5 0.00011611723588208944 -0.00021768342863873308 2.627078087022334e-5 -4.2932188955486825e-5 -9.723548006371357e-5 -0.000129999102648467 2.1690145365541572e-5 0.00014264214681822146 1.7675663036757896e-5 9.095640692754914e-5 -1.1205274824928154e-5 0.00019969143162376518 -0.00012916488500416937 0.00010231088075449031 0.0002362456680544399 6.777524234854891e-5 5.451283417420723e-5; 0.0001246470590524756 0.00013274693155497595 -2.937085373898437e-5 7.485146015485898e-5 6.512546933325077e-5 -6.68025579250008e-5 -0.00010541756383074286 -0.00016385807721573786 -2.3184378009061058e-5 -7.647193452041503e-5 6.0998575278317246e-5 -9.467761738389618e-5 5.011142893297733e-5 0.00016150350965861646 0.00014428644026000205 2.9506871939046093e-5 0.00017615773940295165 2.601066334707806e-5 -3.87592019285155e-5 8.951714149848969e-6 0.0003164974721551505 -9.154621984971562e-5 -7.951632615342555e-5 -0.00010215757829766388 -6.782330201869584e-5 -0.00014803716221987976 -0.0002302542900031049 -5.961431925581692e-5 -9.727169272420808e-5 
2.4851104896043362e-5 -5.673943924532208e-5 -0.00012008659287286608; 0.00011566375134130298 5.742426175124971e-5 3.166273337406414e-5 -8.691156370374698e-6 -8.942873727214217e-5 -6.000781694460744e-5 2.577743681463599e-5 0.00019651183493984232 0.0002005028868661145 0.00012090313931544586 -2.908270151969784e-5 -0.00013580499988897454 -2.312520017589688e-5 -3.3672863428820257e-6 -9.327353517441809e-5 0.00018135844539934136 -0.00018096715983851656 -8.09332491262888e-5 8.81502560441311e-5 -0.0001567780350233577 9.923968856303767e-5 5.6648390011039486e-5 9.294569418848653e-5 9.837991047966263e-5 -1.7440550880608307e-5 -0.0001275771141882137 3.208032604739322e-5 0.00024450527372387533 -3.8905553527043606e-5 -3.346829224760601e-5 -0.00010009823062141114 -0.0001287761192434562; 3.71979387140796e-6 -2.1686138903041394e-5 -0.0001640494082361846 -8.441804207412018e-6 -4.359868388333498e-5 0.00013304978357567916 -0.00012214825972974543 0.00024771295294571515 0.00020027771107792257 -7.677107340483161e-5 -0.000282580578537325 0.0001062125739868339 -7.226525475330864e-5 1.4822363586553381e-5 -3.5537941480886614e-5 0.0001625949881732077 0.00012590489506195967 -8.335761948315649e-5 -0.00019249269815687305 0.0002723629078123259 -9.335886611987716e-6 8.000038370893915e-5 3.764486620142792e-5 6.53212920174519e-5 -3.1402683712675765e-5 -6.972395008650679e-6 -1.6846424737611782e-5 -7.872228878609985e-5 -4.4341540965849625e-5 1.1306943738000577e-5 0.00010358752489495887 4.576260029844118e-6; -5.29405583997987e-6 -5.528690049155139e-5 8.942028766418557e-5 0.0001367513960265987 2.672593477277298e-5 -0.00018441998140966126 -7.833010927620682e-5 -0.00020553964164187459 1.678289849755173e-5 -6.404430492162048e-5 0.0001202640615209438 0.00016463897740854585 5.157493983083091e-5 -0.000186454193639435 8.412841820777066e-5 1.589873868019529e-5 -0.0001469405778645108 8.089867883049561e-5 0.00015901293049486967 3.6567185323655967e-6 -5.109776061924274e-5 -0.00019310418606816967 
-0.00010694731641174653 2.2091231627058858e-5 0.0002329688341395422 5.048746478180121e-7 6.786412818277895e-5 0.00010880501763117239 0.00017421074703256295 -1.9578037567185017e-5 1.4894047167987862e-5 -4.1589895730936144e-5; -5.06817689218636e-6 -9.506559392737072e-5 -9.320326892028348e-5 -6.801141528247309e-5 5.220636928735524e-5 -6.045510774544876e-5 0.00013161030262874102 -0.00018538685792426337 -8.858425819613929e-5 3.509334608262725e-5 -6.947812464621279e-5 9.722433794687902e-5 -2.6151958013162734e-5 -0.00010634438324320411 -1.5961568282893045e-5 -4.816401808473975e-5 -4.0466360811297445e-6 -0.00014464625376111304 -8.099856306633424e-5 -3.207717031323994e-6 -4.797785179517173e-5 -0.00027615306128109425 2.157120986832516e-5 8.014834209045726e-5 9.000887264430521e-5 0.0001041397645853265 -8.008979596032369e-5 -0.00011603400882866907 7.13421338825101e-6 -2.148517698931404e-5 5.7292067208813736e-5 -5.832711933596717e-5; -4.350315897229809e-5 0.00011202222491645159 0.00016667974965909701 -5.210975826687662e-5 -0.00019196800452394195 -1.6046976393245493e-5 -9.101892951573081e-6 4.1492958976023335e-5 -0.0001000682579024366 -9.686734424603117e-6 -1.4337511979672118e-5 -0.00014129027580113738 0.00022088665251883156 -2.2478465156760712e-7 -0.00013532304468301607 4.8929646131882846e-5 4.795094615926579e-5 -2.3126819664892155e-5 -0.00023696505844330006 -1.0548097757245019e-5 4.752476058004313e-5 -5.5048135559418026e-5 7.723261537624823e-6 -3.224576526028337e-6 9.246281375906223e-5 8.885269459367463e-5 2.78869764451623e-5 -0.0001820095177527307 -4.2516317203106444e-5 -0.0002306907528100269 -0.00011190104590698871 -4.1292955879746074e-5; -3.2894739979884587e-6 0.0001557957624208367 0.00014197653612422567 -7.219561999596201e-5 -7.603563031812399e-5 -3.857367815732334e-5 0.00016017145057043472 9.786020777023182e-5 -0.00013650863153095142 8.864447985610642e-5 9.025238645330503e-5 -5.026924523922965e-5 3.335830873760901e-5 -6.149234795946168e-6 0.00020088913546709793 
-0.00019177038817365865 9.966095634803515e-5 0.00011218218091011523 2.2118242068193314e-5 -6.0480142256055795e-5 2.5671634229980548e-5 6.603297182652844e-5 0.00018631964152513412 2.012279392088612e-6 0.00024632181552073116 8.289400679126577e-5 1.1427214612871251e-5 -0.00011566639628659231 7.864635645040274e-6 -2.2166504695689928e-5 -0.00012379436126143858 -4.761565984159823e-5; 3.023516117147972e-5 5.824848383392972e-5 2.1838676980881286e-5 -1.1842999560281007e-5 -0.00031595931679865034 -5.825032596864723e-5 1.0915749294986417e-5 -4.117371529380471e-5 -4.272514227001661e-5 -3.155551188934937e-6 7.405033218285564e-5 5.783091298846932e-5 -4.312598932690468e-5 -4.4628416366566876e-5 0.0001260339934339981 -1.9453699427937705e-5 -2.9076549912382927e-5 -4.812725703620815e-5 -1.990976918411923e-5 0.00011573497729454039 -4.374231750647105e-5 4.018647566751243e-5 -2.8960300118591732e-5 -6.785016623714079e-5 -4.350323317728254e-5 -0.0001496473257468763 -2.3846265980206168e-5 -9.250356263171553e-5 -6.029656535636214e-5 7.941242190620669e-5 1.7569379702511482e-5 4.665325628346902e-6; 0.00016627142298048444 -6.961088098805591e-5 2.132237653668605e-5 5.678329301545385e-5 -9.710484033903009e-5 0.0001355256999285302 5.9033615738283564e-5 -3.217458385378963e-5 0.00012639937898156693 -5.019633195930442e-6 -0.00010967494478436208 -3.852431570166683e-5 -0.0001310673079079575 5.279851660261497e-5 2.038373980567778e-5 -4.0435000361044895e-5 3.477257172007989e-5 0.00010529609692994502 7.376531987929725e-5 -0.00013571222104119944 0.00016161924844152706 -9.236651118601362e-5 8.922987884855267e-6 -8.193899435909722e-5 6.509593671491348e-6 2.785732879956605e-5 -2.3310055499765613e-5 -1.1417682859333304e-5 -0.00010012188892332809 0.00016974035218836559 -7.106802974348786e-5 0.00015622627955408532; 6.397610954523733e-5 -0.0001840622241208084 0.00011752617533134947 1.886052845735267e-5 5.87078687611753e-5 0.0001573794473596161 -0.00014616295416557533 4.132578791983143e-5 0.00010566613886554335 
7.281777320245924e-5 2.011130377000826e-5 -0.00017744508991667361 -1.0769776649922786e-5 -7.955307560691049e-5 8.846386976856046e-6 -5.913040845838657e-5 -7.57156628016134e-5 -0.00015273212525711695 7.443686940732226e-6 0.00014308850617895377 0.00016839948179481732 -9.505956009447416e-5 5.3912354219264036e-5 -0.00017166455987093388 0.0001560894782822826 0.00010939030866653337 -5.921550442066335e-5 1.2714056851891297e-5 0.00012364656535724815 -0.00015911136475109828 -5.168638349927311e-5 4.861995361794373e-5; 1.5810655108752215e-5 -5.065073752590744e-5 3.848677909884707e-5 7.78321892856465e-5 0.00013020840300198753 2.1268005307171647e-5 -1.6573746384136712e-5 -0.0001313441560934472 -5.334793501348018e-5 -3.6260159381682204e-5 -0.0002010805281908405 -5.210651456918695e-6 -1.0508175241893206e-5 -6.689046391799096e-5 2.3403298778094668e-5 8.666990937383965e-5 4.490081949809959e-5 -4.910287648425975e-6 0.00011603697670109843 -7.63087326642137e-5 -0.00012519464768513154 2.6931121405941937e-5 -0.00015224486820913463 -0.00010324694709667182 0.00012737559621640098 -2.6596632656189322e-5 -0.0001065629647793413 -6.220980400500687e-5 -3.6403142668585636e-6 3.7691164047677605e-5 -6.229339748203654e-5 -8.852512370609891e-5; 0.000242040904358216 3.050517328963777e-5 7.594705379167594e-5 -0.00016296212130393755 -4.228226613188136e-5 -0.00013318993593267883 -1.552538129610391e-5 -1.0349776185396772e-5 -0.00011035794455958239 -0.00011095646483292763 6.159252074146722e-5 2.5208823801925877e-6 1.7024567471503747e-6 -0.00010229294572192936 -3.0283033320958888e-5 6.0268114556744976e-5 2.1672087986442302e-5 3.009294845403711e-5 4.793498805090853e-5 -0.00018624231894202086 9.303507680605409e-5 -6.382893048858908e-5 -3.243312608967437e-5 -0.00024162203340958614 0.00011185487959599258 5.5010365703563505e-5 -3.2965438786707286e-5 0.00018913350208029186 7.717645779782963e-5 4.7835148864382904e-6 5.657702853502729e-5 -4.5943393398121056e-5; -4.916167308338447e-6 0.00018245260424045427 
-9.505843638145102e-5 -0.00012416100809751652 -4.516880664560678e-5 6.449270575139356e-5 -0.0001452336947473753 -0.0002007349031551807 -2.1704607839685e-6 8.90426089248556e-5 -3.729700631048164e-5 1.8444154084490984e-5 0.0001574734413956233 -8.280953636095198e-5 -0.00017453557677036725 3.564070710161971e-5 0.00011641741719394385 6.108862715372234e-5 -4.964948056667415e-6 -0.00024148244858188453 0.00019049732498425568 -8.489398180207569e-5 3.5709050171491005e-5 9.548120568105588e-5 -1.4300099901091267e-5 -4.610659754653947e-5 -0.00015862186421109983 0.00013733104413439794 5.457046416303468e-5 -5.6523127677435635e-5 -3.333107098988926e-5 1.9923174368516454e-5; -0.0001495482347959806 9.905301060684139e-5 -7.010409815080121e-5 0.0001236506004065115 -4.020678326520233e-6 -2.7291355828819524e-5 -0.00013828914149740466 0.0001338088558699783 4.063212411750943e-5 8.26517434038067e-5 5.032362690002906e-5 -4.916180393734653e-5 -0.00012095521294355123 -8.859466385871186e-5 3.122134596964599e-5 -0.00011484354679083114 2.0290349246754665e-5 0.00010140014722168707 0.00016565713388303572 3.9683732146328535e-5 8.03971845782438e-6 2.066468817103882e-5 0.00015239677389076203 -0.00023597850583570004 -3.7197139201050924e-5 0.00013539577132946063 8.610927118610886e-5 -0.00010282235513982963 6.478630374152514e-6 -3.6000167906952484e-6 -1.9288175346042318e-5 0.000127051121060905; 7.059074423399392e-5 0.00011196640125366786 1.6809948736988837e-5 -0.00013967398981672976 -7.120106593116569e-5 1.427142078559198e-5 1.733882172506055e-5 -6.382659185088614e-5 -0.00016999424397122823 3.5570981798298924e-5 2.2982119934239658e-5 0.00011255533909083028 2.2286511101482378e-5 -7.967768751892154e-5 0.00014393578030525442 2.8675354757413483e-6 2.393201075268692e-5 -4.2331270289069906e-5 -0.0002094599111854779 -0.00014548666504320118 6.327408491294266e-5 8.032619380052678e-5 -6.1171049220648e-5 0.00010646619930288745 -0.00014293235281125766 -0.00011962778799112112 -6.731365274514392e-5 
-7.86414601874295e-5 8.808117136017113e-5 -5.001023591954351e-5 0.00014251762340667274 1.2857533964635514e-5; 5.0788028822671384e-5 2.3879958682623148e-5 -1.1327105091959786e-5 9.970362611270632e-5 -0.00015339858487791429 -3.886573518470266e-5 -0.00021929258196691831 1.353263200357982e-5 8.505061145329767e-5 -1.6366731257675382e-6 -8.331842543746843e-6 -5.7276974764726616e-5 -5.392111030018796e-5 5.835381693246521e-5 7.129188474423375e-5 -0.00021498422552908812 -4.514802254347869e-5 -7.958670881421195e-5 -2.820095119900444e-5 6.16934996672764e-5 -6.0517060553690985e-5 -2.471168568954973e-5 0.000115469127415374 -0.00012174313429354183 -7.391342638370254e-5 1.2528008703478182e-5 -5.2439066855776656e-5 -0.00010114398879309217 1.7826420155279947e-5 6.0409791551459703e-5 4.36239972644666e-5 9.129423437330417e-5; -7.891707321923734e-5 -9.840723181205072e-5 0.00014986709026324963 -0.00010302253907375442 -6.821955288123864e-5 -0.0001746073390991993 -6.13979351643139e-5 2.3243345245099098e-5 -3.0149675554590945e-6 6.155884363160912e-5 0.0001657517135439774 -3.712786168271204e-5 -0.0001369947273220537 0.00013068376607898615 -0.00019196345204782142 7.111220508286209e-5 3.857847001690995e-5 4.6293648088430577e-5 9.733780495062045e-5 7.993446376113698e-5 1.332134376337179e-5 2.0723557432056512e-5 -2.1666586357787075e-5 0.00010969871833858877 -2.2200508860440727e-5 8.19274358631455e-5 -7.731197514167518e-5 7.45075961871394e-5 -3.036290002874081e-5 -0.00019272277098443576 -0.0001309912764869311 -0.00012409311569590276; 7.561639311547996e-5 -4.675536958443101e-5 -9.195085341984748e-5 -3.06887200340933e-5 -2.4835859693711302e-5 5.744229882198382e-5 0.00017178499500091886 -6.405320938852019e-5 -2.0828214946089876e-5 -7.998754201166644e-5 0.0001808659576281136 -6.282099686272733e-5 8.090904837531702e-5 -1.104316153855023e-5 -3.8016533398193015e-5 -5.7566209324294346e-5 0.00012053237103914628 -0.00021711534693081186 4.6281560317034033e-7 -5.42059974750743e-5 -0.00011724256229993119 
-9.051563440064044e-5 -2.1759319241980368e-5 -8.991719416284983e-5 0.0001958028001181022 4.970726830011067e-5 5.154524370438277e-6 0.00010416618122776641 9.833548102462167e-6 -0.00023334701883781812 9.646551869168534e-5 -8.390988707019082e-5; -9.240373429528555e-5 0.00012607732395955923 -0.0001094393922427034 -0.00011668233156123088 2.7858382684288826e-5 0.00010098009415052746 7.006563436553963e-5 -0.00010649932328996392 -0.0001463778938228419 6.48196664079481e-6 -0.00011195870711852666 -7.422863132790038e-7 7.070179316761965e-5 -0.00013807007445173296 7.973991495638265e-5 7.800924952252238e-6 -5.449409246247039e-5 -0.00024628166675719334 0.0001891980170443381 4.771552602727164e-5 -0.0002238241470945686 -6.644874184305184e-5 -2.1588645092912614e-5 4.2194223799395956e-5 -0.00012944070140982552 6.643139540066374e-5 -3.234999558807369e-5 0.00013295387515646273 -0.0001298336176729108 2.9737495144596863e-5 -8.0152666039422e-5 -4.478779765812019e-5; -1.4286202093772728e-5 5.4815959339191416e-5 6.232522781023286e-5 9.705195469099498e-5 -1.735139932758179e-5 -3.495775105087139e-7 -3.4259959028405875e-5 0.00011209413121936368 4.20680595935989e-5 -0.0001622033517549981 -0.00010702689116338803 -0.00017336103275634852 -1.4933447636268641e-5 0.0001285078548618334 0.00015867981044824636 0.00016530125201928887 3.611689015191424e-5 -1.4878245855344596e-5 -0.00015608886980668227 4.777981364998932e-5 -5.640604587722535e-5 -0.00015356728483209412 -3.70573828321192e-5 0.0001672447767133572 -1.597980310076518e-5 2.1241564064564834e-5 7.908835578872462e-6 3.695436014926515e-5 0.0001260565265339562 4.835290808945273e-5 -3.109209803265988e-5 -0.00012976535737731215; 0.00011497082311546547 -4.2501243232953817e-5 3.0833104456571256e-5 8.428952139223088e-5 -3.8461031299055236e-5 -3.849675988892076e-5 -3.401410258055218e-6 -0.00015133566478613643 -2.891187535130444e-5 7.042014919573555e-5 8.72889476071866e-5 6.209094745310644e-6 -6.610864068495012e-5 -5.656597868045761e-6 
-3.184667438807293e-6 -0.00013108366249941169 -5.472028999935752e-5 -3.9311456138937694e-5 8.382178645700314e-5 5.729138217287025e-5 5.8753103867661554e-5 -0.000100296546524601 6.250441662853667e-5 0.00010539474612453983 1.6059773274095235e-6 0.00018525294296562744 -8.68987274731633e-6 9.466925700547292e-5 -5.136277753041951e-5 -5.4746774485073466e-5 -0.00014254595874223732 7.94208505834024e-6; 1.2670794286189367e-5 -6.1032394432165275e-5 -3.5751603256149803e-6 -9.536993761610306e-5 -1.5925720382957473e-6 1.230744023889488e-5 8.389410927595506e-5 -0.00014150679212359784 8.227640923980583e-5 -0.0001042195829988278 -5.680176341154677e-5 8.852169835410945e-5 7.683132994390957e-5 -0.00010808966485381172 8.409788702085548e-5 -0.00011290639245014709 2.722106605763019e-5 -2.6245415346399598e-5 -2.5399723516945454e-5 0.0001263519995071178 4.360171137780017e-5 3.0050417997922178e-6 -0.00011316781033126701 8.223878526298277e-5 -4.097583891887275e-5 0.0001423170657752262 2.4253868696926418e-5 -2.6038094210141106e-5 -0.0001851139198599204 -7.499852789966946e-6 0.00017657295816281735 2.446830571973178e-5; -8.774890127513431e-6 -3.528137015351021e-5 -0.0001521978635788929 -6.038295603190272e-6 5.2332597326623636e-5 -8.326878808261589e-5 4.7810858879842116e-5 8.154690956175e-5 -3.564473147675861e-5 6.317689191959506e-5 -6.16464828705432e-5 -7.305847455115491e-6 0.00010263652736497832 0.00011686760924235298 6.497470828648475e-5 2.605782108034143e-5 2.4760502331849137e-6 0.00014118207697933348 -0.00014548227111632216 -0.00010210747198256663 -0.00018072612213256427 5.657294175069342e-5 2.0523713167080383e-5 -0.0002311658251780069 3.891520034116692e-5 4.770176317137625e-5 -0.00010963726593004797 -2.8642951954368907e-5 0.00010068072813041271 -0.00020162850771853649 -2.0482530751594978e-5 -4.105359469363334e-5; -0.00010095152714357187 4.485206415111384e-5 -2.455016686881409e-5 -0.00023428540765607934 0.0001274540027806305 -1.4449657952438885e-5 4.434283806758761e-5 
-0.00019636649261529053 8.907249690665714e-5 2.058867181758491e-5 -6.338957058333049e-5 -6.611988224995666e-6 5.3544479727770496e-5 -5.963837423777591e-5 4.193974206273238e-5 9.420535030940115e-5 -0.00015050275572875624 0.00011086190176129118 -4.132381394456391e-5 -7.770815312695127e-5 3.564090917639642e-5 -8.045494354159137e-5 3.867675883511901e-5 -0.00016456427141392288 4.417409042062309e-5 -7.198542328851287e-5 0.00012013278143413045 4.170985454393135e-5 7.81933982934744e-6 -8.174604766836241e-5 -0.00018308046304455475 -0.00011983406324006343], bias = [9.597275820016463e-10, 5.508600884000342e-10, -6.739648575939672e-10, 2.064677560317983e-9, -1.0212161857275647e-9, -5.497948749797999e-9, -4.499798515071002e-9, 5.3694353407263754e-11, -9.215034165494342e-10, 1.1709853013293928e-9, 9.969103695249072e-10, 8.640616641577613e-10, -3.6360912283475337e-9, -2.6167390661292588e-9, 3.1471608467479754e-9, -1.9051406315962506e-9, 1.2395811649983774e-9, 2.367822704775623e-10, -2.2363248457160964e-9, -5.777662617580919e-10, -1.0551226392295494e-9, 9.539948624435667e-10, -1.3168586177159872e-9, -1.9570121826772853e-9, -1.367547703629242e-9, -1.019106138213937e-9, -3.0452680139073694e-9, 6.772835505904456e-10, 3.2457176721297434e-10, -6.852967735719326e-11, -1.739615671044894e-9, -2.64293348868611e-9]), layer_4 = (weight = [-0.0006421090468990606 -0.0005881608494292285 -0.0006945143249889012 -0.0006282800157326237 -0.0005988092238042243 -0.0005882938381387899 -0.0005717172402583216 -0.0007781352297308512 -0.0006233735813501144 -0.0007934714486144631 -0.000719181702679124 -0.0005208286904723128 -0.0006714955935882372 -0.0008087253579617997 -0.0007287755319422764 -0.0008327000735386744 -0.0006297607884815511 -0.0007658357458621108 -0.0007735007236133582 -0.0006643161512328031 -0.0006860499282772278 -0.0006930380220331149 -0.0004642500055866788 -0.0006293263407870992 -0.0007629819279545917 -0.0007853778648902183 -0.0007454718223754854 -0.0008837911948911938 -0.000687814403267672 
-0.0006817594112591003 -0.0007341134383526257 -0.0008456639065096693; 0.00021233478627238387 9.32618809328995e-5 0.00033466948735778244 0.00018578359664727926 0.0002191518519506535 0.00017153919760419414 0.00037179483854547795 0.0002630895408376468 0.0003447877944349889 0.0002982982107316045 0.0001957094823716738 0.00017359283311172032 0.000372190867356755 0.0002973527358351798 0.0003010800738000428 0.00039449378436994616 0.0003353601045246481 7.259458359222763e-5 0.0003340822532792041 0.00015580422961647243 0.00023582457282409484 5.037627543844286e-5 0.0002665133412131929 0.00029024983485525705 0.0003236668874506661 0.00016888407378369604 0.00016594532771380837 8.984108770345024e-5 0.00024305646483691585 0.0001044188368554802 0.0002022031379187777 0.00019705069381599933], bias = [-0.0006696101528693621, 0.00023034508678988706]))
Visualizing the Results
Let us now plot the loss over time
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")
    # Loss history recorded by `callback`, drawn as a line with a marker per iteration.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)
    fig
end
Finally, let us visualize the results.
# Re-simulate with the optimized parameters `res.u`, using the same fixed-step RK4
# settings as during training.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params
) |> first
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")
    # Reference waveform data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
                  markersize=12, strokewidth=2, alpha=0.5, marker=:circle)
    # Prediction with the untrained parameters.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
                  markersize=12, strokewidth=2, alpha=0.5, marker=:circle)
    # Prediction with the trained parameters `res.u`.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
                  markersize=12, strokewidth=2, alpha=0.5, marker=:circle)
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
Appendix
using InteractiveUtils
# Print Julia version, platform and environment details for the run that produced this page.
InteractiveUtils.versioninfo()
# GPU backend details are printed only when MLDataDevices is loaded in this session
# and the corresponding backend package is loaded and reports itself functional.
if @isdefined(MLDataDevices)
# CUDA: checked first; `println()` separates its report from the previous output.
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
# AMDGPU: same pattern as CUDA above.
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.