Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model, but from a scientific machine learning perspective they are not especially relevant.
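We need a very crude two-body path. Assume the one-body motion is the Newtonian relative position vector $r = r_1 - r_2$, and recover the two bodies' positions with the Newtonian centre-of-mass relations $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, where $M = m_1 + m_2$.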
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of the two bodies
about their centre of mass. With `M = m₁ + m₂`, returns `(r₁, r₂)` where
`r₁ = (m₂ / M) .* path` and `r₂ = (-m₁ / M) .* path`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    return (m₂ / total) .* path, (-m₁ / total) .* path
end
Next we define a function to perform the change of variables:
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution into Cartesian orbit coordinates. The result is a `2 × N`
# matrix whose first row is `x` and second row is `y`.
#
# `soln` has either
#   * 2 rows `(χ, ϕ)`: `model_params` must then hold the three constants `(p, M, e)`, or
#   * 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary per column and `model_params` is ignored.
# The radius is `r = p / (1 + e cos χ)`, with `x = r cos ϕ` and `y = r sin ϕ`.
# `M` is not used by this conversion.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` first: `length(nothing)` would otherwise throw a
        # MethodError instead of this assertion's explanatory message.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
This function computes the first time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly sampled series `v` with spacing `dt`.
Interior points use the central difference `(v[i+1] - v[i-1]) / 2dt`. The two endpoints
use second-order one-sided stencils. `v` needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    first_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    last_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(first_edge, interior, last_edge) / dt
end
This function computes the second time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly sampled series `v` with spacing `dt`.
Interior points use the three-point central stencil. The two endpoints use the
second-order one-sided stencil `2, -5, 4, -1`. `v` needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    body = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, body, tail) / (dt^2)
end
Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the mass-scaled, trace-free second-moment tensor along an orbit.
Row 1 of `orbit` is taken as `x` and row 2 as `y`.

- `(1, 1)`: `x² - (x² + y²)/3`
- `(2, 2)`: `y² - (x² + y²)/3`
- any other `component`: `x y`

Each result is multiplied by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs, ys = orbit[1, :], orbit[2, :]
    Ixx, Iyy = xs .^ 2, ys .^ 2
    third_trace = (Ixx .+ Iyy) ./ 3
    tensor = if component[1] == 1 && component[2] == 1
        Ixx .- third_trace
    elseif component[1] == 2 && component[2] == 2
        Iyy .- third_trace
    else
        xs .* ys
    end
    return mass .* tensor
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (spacing `dt`) of the chosen trace-free moment-tensor
component from `orbit2tensor`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(h11, h12, h22)` quadrupole components for a single body following `orbit`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    # Small closure so the three components differ only in their index pair.
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

(2,2)-mode strain pair for a single body. Returns
`(√(π/5) * (h11 - h22), -√(π/5) * 2 h12)`. The constants are built in the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    amplitude = √(T(π) / 5)
    plus = amplitude * (h11 - h22)
    cross = -amplitude * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the per-body `(h11, h12, h22)` quadrupole components of two bodies.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q1 = h_22_quadrupole(dt, orbit1, mass1)
    q2 = h_22_quadrupole(dt, orbit2, mass2)
    return q1[1] + q2[1], q1[2] + q2[2], q1[3] + q2[3]
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

(2,2)-mode strain pair `(√(π/5) * (h11 - h22), -√(π/5) * 2 h12)` for two bodies.
The bodies follow `orbit1`/`orbit2` with masses `mass1`/`mass2`, which must sum to one.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass2
    total_mass = mass1 + mass2
    # `compute_waveform` builds the masses as `inv(1 + q)` and `q * inv(1 + q)`, which sum
    # to one only up to rounding. A fixed 1e-12 tolerance rejects Float32 masses that are
    # off by a single ulp. Scale the tolerance with the float type, never below 1e-12.
    tol = max(1.0e-12, 10 * eps(float(typeof(total_mass))))
    @assert abs(total_mass - 1.0) < tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into the `(h₊, hₓ)` strain pair.

- `mass_ratio` must lie in `[0, 1]`.
- `mass_ratio == 0` is treated as a single test particle.
- Otherwise the relative orbit is split into two bodies with masses
  `m₂ = 1 / (1 + q)` and `m₁ = q m₂`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # Test-particle limit: no need to split the orbit.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
Simulating the True Model
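`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^{2}\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2}\sqrt{(p - 2)^2 - 4e^2}},$$

where the state is $u = (\chi, \phi)$ and $p$, $M$ and $e$ are constant model parameters. The orbital radius is $r = p / (1 + e\cos\chi)$.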
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side `[χ̇, ϕ̇]` of the relativistic orbit ODE for the state `u = (χ, ϕ)`.
The constants are `p`, `M` and `e`. Only `χ` enters the rates; `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecχ = e * cos(χ)
    numer = (p - 2 - 2 * ecχ) * (1 + ecχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * ecχ) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end
# Settings for the "true" relativistic simulation that generates the training data.
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
datasize = 250
tspan = (0.0f0, 6.0f4) # time span for GW waveform
# NOTE(review): `tspan` uses Float32 literals, so `tsteps` and `dt_data` presumably end up
# Float32 while `u0` and `dt` are Float64 — confirm this mix is intended.
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
# Spacing of the saved samples; passed to `compute_waveform` for the finite differences.
dt_data = tsteps[2] - tsteps[1]
# Fixed RK4 step size (the solves use `adaptive=false`).
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the true relativistic model with fixed-step RK4, saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first strain component returned by `compute_waveform` as the target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    # Plot the target waveform as a line plus one marker per sample.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    # Group the line and markers under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the phase variable $\chi$, i.e. `first(u)`, not the time) and has two outputs. We'll make a function `ODE_model`
that takes the current state, the neural network parameters and the time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl (`truncated_normal` for the weights and `zeros32` for the biases).
# Network layout:
#   1. The input is first passed through `cos`.
#   2. Two hidden `Dense` layers of width 32 with `cos` activations follow.
#   3. A linear output layer produces two outputs.
# Weights are drawn from a truncated normal with std 1e-4 and biases start at zero, so
# the untrained network's output is small.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): `Random.default_rng()` is not seeded here, so the initial parameters are
# not reproducible across sessions — consider passing a seeded RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-3.9736744f-5; -8.100223f-5; -0.00012865887; 0.00011418679; -0.0002193272; 1.1995103f-5; 6.9720474f-5; -3.280083f-5; -8.3711406f-5; 5.154774f-5; 8.354386f-5; 3.9065733f-5; 1.9090638f-5; -2.2447792f-5; -2.321735f-5; -3.515181f-5; -3.482303f-5; 3.3506993f-5; 0.00014155087; -0.00019003934; 2.3638975f-5; 0.00016343025; -8.934215f-6; -0.0001254526; -7.114121f-5; -8.1508995f-5; 8.804455f-5; -0.00012339432; -3.8679584f-5; -1.5167665f-5; -0.00010396337; 1.6909087f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[1.7935725f-5 -0.00015549544 -7.751984f-5 -3.456003f-5 -1.3655149f-5 0.00017112146 7.07627f-5 7.1257506f-5 -3.25645f-5 6.9467824f-5 -3.29976f-5 0.00020027038 6.142229f-5 0.0001485865 -0.00013388951 -2.6031628f-5 -3.7917291f-6 -7.5430384f-5 4.2730044f-5 -0.00011818168 2.9986739f-5 -0.00028363767 3.037674f-5 -5.659034f-5 -1.0186567f-5 7.038242f-5 0.00018951524 -1.3182084f-6 2.7812798f-6 5.863221f-5 4.8647882f-5 0.00022037084; 7.727802f-5 5.021629f-5 0.00014215699 -0.00010077979 -1.5439324f-5 -5.672547f-5 2.1785287f-5 -2.3377897f-5 -0.00020950673 -9.186197f-5 6.2202154f-5 0.0001402641 6.056759f-5 0.0001306753 7.873351f-5 9.0749934f-5 -5.235423f-5 0.0001261237 8.690858f-5 -9.432747f-5 -4.9440372f-5 -5.570978f-5 0.000117910145 0.00014855294 -0.00023869355 0.00017070935 -4.4799697f-5 -5.914114f-5 -0.00015890242 0.0002226505 -0.00011167871 3.218813f-5; 0.00012429125 -2.4568852f-5 1.0810708f-5 6.421562f-5 9.9294884f-5 3.8953402f-7 0.00021608997 -0.0001560444 -0.0002696381 8.4502666f-5 0.00020474572 0.00011345782 2.8734015f-5 -3.6770987f-5 0.00012452422 -8.277444f-6 0.00013250227 -7.60699f-5 -4.065215f-5 0.00031907443 -0.00012368236 -1.0401008f-5 0.00016737798 -0.00014154108 0.00016421794 7.127389f-5 0.0002474245 -0.00013711116 5.299611f-5 0.00014759206 
-1.9560075f-5 -0.00010989505; 0.00010891673 2.2201164f-5 4.0307328f-5 -9.787538f-5 -9.4647556f-5 2.9063125f-5 9.367381f-5 -0.00013454058 7.690227f-5 4.0639756f-5 0.00018190539 -3.436958f-7 -6.6984154f-5 -4.4521163f-5 7.3254714f-6 0.00017981823 6.283147f-5 -5.3407413f-5 -6.423371f-5 -0.00020602577 -0.00010924324 0.000107218 3.866338f-5 0.00020267682 3.6803885f-6 -4.585936f-5 -3.4023553f-5 -2.5157306f-5 -0.00010921093 -0.00011025234 9.168811f-5 -8.258523f-5; 0.00019495355 -0.000111747344 7.459559f-5 -7.554976f-5 7.5442396f-5 4.96797f-5 1.4421968f-5 6.790518f-5 0.00015888327 6.224219f-5 -3.0592382f-5 6.441116f-5 0.0002338929 3.492728f-5 -0.00020939192 -1.2428464f-5 7.6491706f-5 -5.2979216f-5 6.1117644f-5 -5.4637553f-6 0.00013458199 -9.844644f-5 0.00027515506 -8.654733f-5 2.2068827f-5 -0.00014948247 1.4715229f-5 -4.724503f-5 -4.8545662f-5 -3.7462465f-5 5.5683795f-5 7.033931f-5; 0.000106490486 -7.62674f-5 -6.675236f-5 0.00018577841 0.00011445821 -6.526331f-5 -0.0002379662 -8.5474945f-5 5.486302f-5 2.0650652f-5 -0.00010815091 0.00016901849 -1.328619f-5 -0.00018880476 0.000107160566 -0.00017517131 4.9038135f-5 -8.5456566f-5 -0.00030721282 4.8688355f-5 2.5112908f-5 -1.1795071f-5 0.00014874118 4.79654f-5 -4.5615775f-6 -1.0598759f-5 3.755133f-5 -0.00010364145 -3.6278068f-6 -0.00013439609 1.8764122f-5 2.1076303f-5; -1.015282f-5 -0.0001510425 0.00015956115 2.412534f-5 -1.5767988f-6 -0.00014400155 2.4344929f-5 -4.5383663f-6 0.00014828047 0.00012543406 0.00011074537 4.7981208f-5 -0.00017192384 -6.9453556f-5 0.00014734645 -7.227704f-5 0.00014774154 0.00020048008 0.0001057396 -8.180788f-5 0.00016365836 8.460146f-5 -7.653607f-5 -1.6959906f-5 -1.7926797f-5 2.4376572f-5 -3.0097133f-6 6.682445f-5 -4.8478974f-5 -0.0001260667 7.2175106f-5 -0.00011875615; 9.7741315f-5 9.143036f-5 0.00010313716 -0.00024314101 -0.00012778386 -5.287856f-5 -0.00013602867 0.00017402423 -3.0037927f-5 -7.178004f-5 -1.9487002f-6 -3.9755443f-5 -3.3175907f-5 -7.2101546f-5 1.4207086f-7 -0.0001922183 3.559498f-5 
1.3678757f-6 -6.722897f-5 8.1183665f-5 0.00017924941 1.1250029f-5 -0.00010614631 -7.374786f-5 -0.00013431798 -5.3986907f-5 5.0388153f-5 -0.00018913577 -8.088461f-5 0.00011323029 0.00015471451 -0.0001034969; -3.980983f-5 -6.843101f-5 -0.00021205022 9.081242f-6 -0.0001873739 8.2450184f-5 -0.00012325883 -9.143445f-6 4.684858f-5 -5.4758828f-5 0.000151162 1.216747f-5 -8.247985f-6 0.00016831297 -0.00019183438 -2.741226f-5 0.00020479888 -1.418336f-5 -0.00018088032 3.636795f-5 -7.0115093f-6 6.165457f-5 3.7600315f-5 6.944206f-5 -2.2611983f-5 0.00012401726 -0.00016369458 -0.00011543884 0.00012676278 2.0249274f-6 -3.2962005f-5 9.7823424f-5; 6.885219f-5 -7.260459f-5 -3.5767366f-6 -3.834479f-5 -1.1532808f-6 4.7001304f-5 -2.1772361f-5 -0.000103855884 0.00022243707 0.00015504746 -0.00015220763 6.9010857f-6 1.2141809f-5 -0.000100387886 -0.00019071305 0.00014079106 1.5449894f-5 5.548808f-5 -3.7774815f-5 -1.2112292f-5 3.681046f-5 3.7620954f-5 -2.2805874f-5 0.00015861297 -4.4536413f-5 7.279161f-5 7.395921f-5 -2.0521502f-5 4.9831662f-5 -3.2674063f-5 -1.4486087f-5 0.00015998687; -5.35221f-6 -3.1185806f-5 -6.1432f-5 -1.3508322f-5 2.1115573f-6 0.00020758752 -0.00015062817 6.5729755f-6 8.668293f-5 5.5566784f-6 -0.00012610843 0.000116010946 2.2437527f-5 -0.00010391088 -3.9931638f-5 -2.3191546f-5 -9.417664f-6 -2.9528037f-5 -0.00017079413 -6.7500114f-5 0.00011664518 -0.00014286564 -0.00032444936 1.2223811f-5 9.399535f-5 9.172896f-6 -1.5347601f-5 -2.4663332f-5 0.00011386941 -3.815818f-5 3.0146616f-6 9.273545f-6; 8.272339f-5 -4.6101617f-5 -8.2737926f-5 3.975672f-5 0.00016296859 -7.748756f-5 -0.00015438999 -4.2720072f-5 7.252678f-5 3.3689357f-5 -2.3169952f-5 -0.0001261189 -2.9482811f-5 8.881661f-5 7.582702f-6 5.9693648f-5 -9.24766f-6 -5.0919367f-5 5.9779272f-6 -9.455208f-5 -1.23216005f-5 7.902919f-5 1.3856021f-5 -0.00017759789 -4.3848762f-5 0.00011858444 1.442277f-5 0.0001521001 -0.00013171254 -0.00013680659 8.259527f-5 -0.00010233997; 9.854694f-5 2.5517393f-5 6.632436f-5 -4.590653f-6 
2.8067787f-5 4.1295876f-5 1.0124318f-5 -1.20012855f-5 -0.00031809235 3.22753f-5 3.848675f-5 0.00012130122 -3.2094882f-5 -0.00011378871 0.00017943654 0.00015649761 -4.2903375f-5 -3.6823196f-5 3.00242f-5 -5.8376696f-5 9.111382f-5 1.666787f-5 0.00011616121 7.652154f-5 3.6276455f-5 -0.00011036569 0.00010652493 4.0326187f-5 -1.4382916f-5 -2.5103778f-5 -0.00017325795 9.088441f-5; -0.0001747127 1.4795508f-5 -4.696427f-6 4.8482772f-5 0.00014615073 -0.00010627147 -5.2725733f-5 -1.4206495f-5 7.769453f-5 -0.00019420765 -0.00013065671 4.3137666f-6 -0.00013215104 -0.00020616285 1.2967698f-5 -0.00018007695 5.7576f-5 2.3632658f-6 6.970247f-5 -2.4290854f-5 3.3510743f-5 1.7983795f-5 -9.28546f-6 -7.923879f-5 8.7215f-5 4.7965983f-5 -6.490151f-5 -4.5572273f-5 -1.9772124f-5 -1.6623453f-5 3.0901996f-5 -4.5646208f-5; 3.3442855f-5 -1.1659241f-5 -3.2423936f-6 0.00011058658 7.637433f-5 0.00012147321 0.00021944691 0.00012874532 3.3710246f-5 -3.7968217f-5 3.6531845f-5 1.0580579f-5 -7.670703f-6 -7.8012825f-5 -0.00021488236 -4.359181f-5 5.1897136f-6 -0.00016289006 -7.175076f-5 -8.845499f-5 -0.00013287834 -0.0001772602 -0.00010851375 7.281635f-5 6.520749f-5 7.33973f-5 8.936558f-5 8.72339f-5 -2.8804772f-5 -2.8753764f-5 9.944154f-5 7.8596604f-5; 0.00020800365 3.1968073f-5 4.0540515f-5 -4.2394804f-5 9.0594236f-5 0.00011556105 -3.0601714f-5 -9.494465f-5 0.00010507307 -9.742889f-5 4.834621f-6 0.00023518635 2.0035799f-5 3.7502057f-5 -0.000100294084 -2.7474363f-5 5.3102143f-5 9.420054f-5 -1.38739715f-5 -5.17223f-5 8.488759f-5 2.3341416f-5 -7.18643f-5 -6.717143f-6 6.725135f-5 4.4466433f-5 2.2151203f-6 -6.32962f-5 0.00014194718 -9.899811f-5 0.0001418153 -1.4105099f-5; -4.7051497f-5 8.3250364f-7 -0.00015998556 0.000119440716 0.00012548229 6.566521f-5 4.8309892f-5 -0.000112106805 -3.5077217f-5 -0.00011596079 2.586442f-5 0.00013492866 -0.00016440154 9.783396f-5 -0.000104227336 6.488869f-5 -3.057006f-5 5.4173852f-5 5.2521507f-5 -0.000120383054 6.290582f-5 8.306425f-5 -4.7814005f-6 0.00023743036 2.4963927f-5 
-5.197769f-5 6.333012f-5 -2.3245959f-5 0.00015308794 7.2773655f-5 9.507123f-5 -5.1025036f-5; -0.00010602389 -0.00024266628 -6.9977286f-5 0.0002121846 9.1075286f-5 -0.00014485925 0.00022440338 -7.1360584f-5 4.228677f-6 3.821319f-5 0.000110424546 -0.00010622903 -1.596922f-5 -0.00016731954 0.00013205818 -4.897999f-6 2.7242111f-5 -0.0002197059 -2.9089775f-5 2.3026652f-5 -0.00014071012 8.86391f-5 -9.805274f-5 -8.8937435f-5 8.06138f-5 -0.00012573933 0.00017853314 3.4971124f-6 3.4160963f-5 -1.4420716f-5 -0.00012142252 -0.00011472782; -5.6367797f-5 0.0001213166 0.00010276994 5.777688f-5 -4.8692473f-5 -0.00019520977 7.2912044f-5 -2.9678175f-5 -5.324345f-5 -2.3004084f-6 -3.7593625f-5 0.00018278002 0.00011677425 5.872141f-5 0.00020279993 -0.0002041241 0.00015618985 8.074553f-5 4.381168f-5 4.1666895f-5 4.7414337f-6 5.892409f-5 -2.3721532f-5 1.254132f-5 -0.00014061481 -2.287063f-6 1.2382225f-5 -2.7832868f-5 9.316681f-5 -5.5036653f-5 -0.00010151345 -9.856252f-5; -1.9427096f-5 -7.766219f-6 -2.830048f-5 3.6490404f-5 -2.297035f-5 -7.506421f-5 -5.7179423f-5 -7.4704025f-5 4.768916f-5 6.431386f-5 7.106304f-5 -7.410735f-5 2.6564941f-5 -4.9502498f-5 -8.674302f-6 -0.000113571914 -0.00013141891 5.471143f-5 9.941384f-5 -2.5108407f-5 -6.584709f-5 0.00021743137 -4.7814232f-5 7.4191834f-5 -0.000111939866 -3.3085973f-5 7.421025f-5 -0.00013186461 0.000107782056 -1.0400752f-5 -0.00012015392 -9.223977f-5; -8.6742504f-5 5.7083013f-5 0.00017299283 1.1994846f-5 -7.334117f-5 6.295701f-6 -0.00027150306 -0.00020976654 2.548847f-5 -0.00013512716 2.0273596f-5 1.2770922f-5 -7.639746f-5 -2.1962333f-5 0.00018631334 8.379929f-5 -8.7749715f-5 9.405523f-6 -0.000121040415 0.00018991291 -2.2519276f-5 -9.7013566f-5 4.432927f-6 -5.3389576f-5 0.0001551898 -1.9281988f-5 -0.00011526415 -0.00018083923 -8.7553184f-5 -5.5232496f-5 0.00016566155 0.0001828607; 1.3114239f-5 5.697844f-5 0.00010871241 -2.1206031f-5 0.00014581023 0.00029361446 9.1063026f-5 0.00011048014 -2.5021174f-5 7.211537f-5 -4.5593406f-5 0.00015375293 
9.718174f-5 -6.2700245f-5 -6.5053224f-5 -1.619737f-5 -5.9831087f-5 -4.8581867f-5 9.409393f-5 2.0775138f-5 9.1415495f-5 -6.761559f-5 0.00016663749 0.00012878155 -8.386769f-6 -7.051302f-5 -7.018555f-5 0.00023534163 8.263021f-6 -0.00011806475 -3.046332f-5 0.0002738205; -3.8068218f-5 0.00013078316 -9.942488f-5 -2.6031921f-5 -6.87008f-5 -6.950731f-5 6.620108f-7 1.1178116f-5 7.2207244f-5 -8.454775f-5 5.727153f-6 2.0703785f-5 -0.00017193936 1.584029f-5 3.0484201f-5 7.216239f-5 2.4618224f-5 -4.6751793f-5 -0.0001323439 0.00012953262 -5.127127f-5 7.0064916f-6 -9.7757555f-5 -8.589784f-5 -0.00013621684 0.0002258606 0.000111529924 -0.00014347097 -3.3989238f-6 0.000146742 -0.000112785 -3.427244f-5; 8.558374f-5 -1.7098246f-5 -5.8415375f-5 7.929349f-5 -0.000112337104 0.00010960261 2.3558889f-5 -9.686581f-5 9.175191f-5 -3.846829f-5 0.00020403575 -0.0001195645 -7.577944f-5 -5.141976f-5 -1.2939917f-5 -6.697691f-6 -1.9087225f-5 3.2954435f-5 -7.1998475f-5 -0.00016200074 -6.735023f-5 -0.00012367815 -0.00018851546 -3.1732856f-5 -0.00012298644 0.00013595933 0.0001528831 2.5186435f-5 -2.4311734f-5 0.00016896709 -1.7739776f-5 0.0001770034; 2.7498088f-5 0.00019149865 8.9321016f-5 3.638526f-6 0.00011158267 2.817006f-5 -5.1551415f-5 -7.4257296f-5 4.4699632f-7 0.00011901795 0.00014610398 2.4414669f-5 9.0392736f-5 -0.00023403177 4.96439f-5 -6.2660445f-5 -0.00018630539 -0.00010112284 7.87026f-5 -9.822152f-7 -5.963788f-5 1.9174235f-5 5.5599456f-5 -1.434045f-5 7.437505f-5 0.00025489423 3.3934622f-5 2.1354992f-5 -0.0001139659 -7.296706f-5 -0.00018131293 -9.879685f-5; -0.00011729412 -1.9600831f-5 -0.00011695361 -0.00014271382 2.7192344f-5 -5.052609f-5 2.1017051f-5 8.101092f-5 -9.264837f-5 -5.6063098f-5 0.00013951196 -5.065976f-5 -0.00021209652 7.177738f-6 -0.000106139836 -7.50996f-5 -5.6091063f-5 -0.00020617903 -5.5091987f-6 0.00019275922 -3.8726404f-5 2.3755614f-5 4.2543984f-6 -1.27321955f-5 -8.03364f-5 -6.0210044f-5 3.41482f-5 -0.00013459625 -0.00015724981 9.530918f-5 -0.00011453899 -2.9754918f-5; 
0.00010490447 2.514199f-5 1.5804166f-5 -0.00014304579 -5.1772684f-5 -6.0245155f-5 5.0619824f-6 -0.00012987637 9.751308f-5 -5.2114294f-5 0.00022345736 6.492333f-5 -8.5176216f-5 -4.274854f-5 5.2399297f-5 0.000102536294 0.0001636009 3.788329f-5 -1.402582f-5 -8.465247f-5 -8.694876f-5 0.0001479338 8.2247025f-6 -5.9899907f-5 2.6614147f-5 7.631466f-5 -2.8850503f-5 -0.00018061238 9.199061f-5 6.280378f-5 -5.5494984f-5 -3.65859f-6; 8.975269f-5 -4.8634916f-5 -0.00020223504 -7.519425f-5 -0.0001542233 -3.7681464f-5 -6.678656f-5 8.370127f-6 -1.760174f-5 -0.000121576275 -5.4713753f-5 -6.412953f-5 -0.0001925484 -0.000119968114 -4.3175947f-5 8.882638f-5 -0.0001227838 -0.00013051355 -0.00012664891 9.957529f-5 2.3088285f-5 -0.00021928336 5.3138065f-5 -6.476576f-5 1.2892521f-5 -5.722506f-5 0.00012990796 -9.154232f-6 -0.00015475905 -5.0260533f-5 -1.1767132f-5 6.328413f-5; 0.0001642891 -6.6397246f-5 0.00013094347 6.515208f-5 -0.00010935248 6.112396f-5 0.00016905996 -1.3631593f-5 -0.00011789206 5.635014f-6 -4.5228415f-5 -3.3471554f-6 0.00013017899 0.00014444056 -2.465046f-5 -0.000113306436 5.9781698f-5 6.277781f-5 6.564196f-5 0.00013470482 -3.0108393f-5 -6.003321f-5 9.126402f-5 -7.766654f-6 4.0213956f-5 -0.00021215278 5.6486697f-5 4.3812153f-7 4.021726f-5 0.00018086613 -9.956863f-5 -0.00019128446; 5.8786132f-5 6.810262f-5 -6.950382f-5 -3.7951555f-5 -0.00014261453 0.0002321655 0.00022871306 3.3061544f-5 -4.211326f-5 -0.00012407814 4.7088262f-5 9.356793f-5 -0.00012596663 8.4263724f-5 2.3046352f-5 0.0001700126 8.745044f-6 0.00010187357 -5.261187f-5 -2.5072062f-5 -4.1749074f-5 0.00010778945 3.8240134f-5 2.2977878f-5 -4.2108753f-5 -0.00010838058 -4.9285212f-5 -4.0719977f-5 -3.877996f-5 0.00011554357 -0.00019508893 -6.144479f-5; 9.477486f-5 -7.5954677f-6 8.090404f-6 -6.2413514f-5 0.00029564332 -5.288918f-5 -9.943804f-5 5.2523028f-5 -8.9570885f-6 -1.8171686f-5 -8.1576756f-5 0.0001122504 6.530863f-5 0.00012861838 1.6694352f-5 -9.806201f-5 1.19108945f-5 -0.00013382781 -0.00010908083 9.6556265f-5 
-4.5278135f-5 -0.000116522104 -7.991439f-5 -3.688171f-5 2.9862586f-5 -3.7474136f-5 -1.2007337f-5 -0.0001445594 2.4633888f-5 -2.8899742f-5 -0.0001898848 -2.4811973f-6; -2.2112126f-5 7.263129f-5 -6.698743f-5 -0.0001478865 1.759168f-5 -0.00020800946 -5.8757065f-5 -0.00014305668 -1.6120912f-5 4.0240735f-5 6.921449f-5 -1.3652434f-5 -3.967269f-5 6.137593f-5 0.00029785844 -4.1532854f-5 -7.023287f-5 1.6445954f-5 -5.6175955f-5 4.2813946f-5 8.364935f-5 -0.00011766804 -0.00013397758 -0.00014323358 -8.2808f-5 -5.7161316f-5 -0.00012373031 2.1950935f-5 -8.2913386f-5 -5.2223873f-5 0.00020733244 4.836751f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[3.241207f-5 0.00017244759 -8.289667f-5 8.9669156f-5 -1.5162434f-5 -4.7221647f-7 4.9921706f-5 -3.235966f-5 0.00011318885 -2.5753297f-5 9.367696f-5 -0.00012941178 0.00012522342 0.00010627981 -6.853386f-6 0.00015291336 -1.9045325f-5 4.999546f-5 -0.00014529955 0.00013873009 3.3928693f-5 5.6186367f-5 -0.00012860051 -2.8889954f-5 -5.590577f-5 9.585365f-6 9.592844f-5 -4.6192894f-5 -3.0267376f-5 -9.41925f-5 7.678082f-5 -4.06406f-5; 2.8664377f-5 -0.00012044511 -5.9711758f-5 -0.00017857346 -7.4254574f-5 7.920887f-5 -0.00012212896 4.2933956f-5 -0.00015043441 -7.323298f-5 9.965944f-5 6.720936f-5 -6.521525f-5 5.98579f-5 3.3309043f-5 4.1723488f-5 -0.00014542993 -0.00011068163 1.1274403f-5 -0.00010492308 3.0720857f-5 5.9517646f-5 0.000201299 -9.149326f-6 -3.7323487f-5 0.00013634357 -7.348439f-5 7.691579f-5 4.739382f-5 0.000105898274 2.2137496f-5 -3.4222452f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most deep learning frameworks, Lux defaults to using `Float32`. However, in this case we need `Float64`.
# Convert the Float32 parameters to Float64 and wrap them in a ComponentArray so the
# optimizer can treat them as a single vector.
const params = ComponentArray(f64(ps))
# Bundle the network with its (empty) state so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
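Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics. The Newtonian rates are corrected by the two neural network outputs $\mathcal{N}_1(\chi)$ and $\mathcal{N}_2(\chi)$:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_2(\chi)\right),$$

where $p$, $M$ and $e$ are the same constants used by the true model.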
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE: the Newtonian orbital rate `(1 + e cos χ)² / (M p^{3/2})`
multiplied elementwise by `1 .+ nn_model([χ], nn_params)`.
Reads the global constants `ode_model_params` and `nn_model`; `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # `nn_model` is a StatefulLuxLayer that already carries the (empty) network state,
    # so only the trainable parameters need to be passed in here.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Neural ODE problem whose parameters are the network weights.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# Integrate with the *untrained* weights, using the same fixed-step RK4 setup as the data.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    # Overlay the target waveform and the untrained neural-ODE waveform.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    # Each legend entry pairs a line with its markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
# Mean-squared-error loss object, shared by every call to `loss`.
const mseloss = MSELoss()
"""
    loss(θ)

Mean-squared error between the target `waveform` and the waveform produced by the neural
ODE with parameters `θ`, integrated with the same fixed-step RK4 setup as the data.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    return mseloss(first(predicted), waveform)
end
Warm up the loss function (the first call triggers compilation).
# Evaluate the loss once at the untrained parameters (value shown below).
loss(params)
0.0007361150812231304
Now let us define a callback function to store the loss over time
# Loss history; `callback` appends one value per optimizer iteration.
const losses = Float64[]
"""
    callback(state, current_loss)

Optimization.jl callback. It records `current_loss` in the global `losses`, prints the
iteration number (`state.iter`) and the loss, and returns `false` so training is not
halted.
"""
function callback(state, current_loss)
    push!(losses, current_loss)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter current_loss
    return false
end
Training the Neural Network
Training uses the BFGS optimizer. It works well here because the Newtonian model already provides a very good initial guess.
# Gradients of `loss` are taken with Zygote.
adtype = Optimization.AutoZygote()
# Optimization.jl calls the objective as f(u, p). Only the parameters `u` matter
# here, so the second (hyper-parameter) slot is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small first step and backtracking line search, starting from the
# untrained parameters `params`.
bfgs_opt = BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01)
res = Optimization.solve(optprob, bfgs_opt; callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-3.97367439290067e-5; -8.100223203654136e-5; -0.00012865886674224983; 0.00011418679059706655; -0.00021932719391743637; 1.1995102795443159e-5; 6.972047412976552e-5; -3.280082819402246e-5; -8.37114057504512e-5; 5.154774044047748e-5; 8.354386227432994e-5; 3.9065733289916716e-5; 1.909063757915007e-5; -2.2447791707201297e-5; -2.3217349735226695e-5; -3.515181015241472e-5; -3.482303145569123e-5; 3.3506992622201456e-5; 0.00014155087410433733; -0.00019003933994122158; 2.3638975108031355e-5; 0.0001634302461751744; -8.934214747542124e-6; -0.0001254525996043802; -7.114120671753033e-5; -8.150899520840987e-5; 8.804455137574985e-5; -0.00012339431850681037; -3.867958366746796e-5; -1.5167664969336369e-5; -0.00010396337165731191; 1.6909087207734646e-5;;], bias = [-1.0331137997516366e-16, -1.3141005002124678e-17, -4.276787990783788e-16, 1.0598656650379092e-16, -3.106164314191496e-16, 4.323962983342471e-17, 2.7717193763773193e-16, 1.1233627758844613e-17, 1.0028736594811233e-16, 1.1993641856539512e-16, 6.031097014880817e-17, 1.4627440027929266e-16, 6.910117990457994e-17, -3.9664245260695566e-17, 2.4800175294853224e-17, -4.6951660069301154e-17, -6.07341900682269e-17, 8.76347803896894e-17, 4.392960865542354e-16, 2.5289291150161677e-16, 1.2517064878551467e-17, 1.0520302317357034e-16, -3.081375037585046e-17, -2.663033898228002e-16, -6.217566674154782e-17, -1.6115140530832142e-17, 4.533455442126217e-17, -2.72396487152735e-16, -8.724213795193438e-17, -4.429684716395399e-18, 1.3988033922784207e-17, 1.7626750990553992e-17]), layer_3 = (weight = [1.793767206438615e-5 -0.00015549348880993256 -7.751789506981836e-5 -3.455808381657581e-5 -1.3653201667700361e-5 0.00017112340447968225 7.076464408785159e-5 7.125945286136415e-5 -3.2562554215454383e-5 6.94697710150872e-5 -3.2995651954459824e-5 0.0002002723313217078 6.142423536405308e-5 0.00014858845407618543 -0.0001338875617784415 -2.602968105745682e-5 -3.789781876525418e-6 
-7.54284364938707e-5 4.2731991025004743e-5 -0.00011817972971584718 2.9988686046659506e-5 -0.0002836357189062185 3.037868646972706e-5 -5.658839198984842e-5 -1.0184619886765299e-5 7.038436616314517e-5 0.00018951718539940066 -1.3162611958107312e-6 2.7832270552838563e-6 5.863415663955152e-5 4.864982930486699e-5 0.0002203727899632451; 7.72797959405701e-5 5.021806848963951e-5 0.0001421587678385585 -0.00010077801359438483 -1.543754640007687e-5 -5.672369095117188e-5 2.1787064023384366e-5 -2.337612001718875e-5 -0.00020950495470168796 -9.186019248880925e-5 6.220393172844518e-5 0.00014026588380949156 6.056936602254984e-5 0.0001306770738927426 7.873528577766958e-5 9.075171167643347e-5 -5.235245204965979e-5 0.00012612548219186797 8.691035530209929e-5 -9.43256943821402e-5 -4.943859473422528e-5 -5.5708003648035205e-5 0.00011791192285164862 0.0001485547129312221 -0.00023869177473096793 0.00017071112436388842 -4.479791984630588e-5 -5.913936344877155e-5 -0.00015890064039120864 0.00022265227947065267 -0.00011167693243914298 3.218990807801143e-5; 0.00012429647592415794 -2.4563630567916447e-5 1.081592888127772e-5 6.422084386046639e-5 9.930010543949163e-5 3.947550409920257e-7 0.0002160951958741726 -0.00015603918603101903 -0.00026963289149518357 8.450788715907875e-5 0.00020475094476767578 0.00011346304069792614 2.8739236377253756e-5 -3.6765766189825396e-5 0.0001245294375350506 -8.272223048887213e-6 0.0001325074959551711 -7.606467573151161e-5 -4.0646929328289314e-5 0.0003190796529180502 -0.0001236771380989074 -1.0395786766690393e-5 0.0001673831980680862 -0.0001415358596476248 0.00016422316236474645 7.127911209880197e-5 0.00024742971852008823 -0.00013710593650997398 5.300133177133047e-5 0.00014759727907421053 -1.955485407036222e-5 -0.00010988983142834181; 0.00010891676530884204 2.2201202152083535e-5 4.0307366286307355e-5 -9.7875343347531e-5 -9.464751754710884e-5 2.9063163234988216e-5 9.367384510989853e-5 -0.00013454053813182195 7.690231169983363e-5 4.063979387576315e-5 
0.00018190542497615878 -3.4365767935790297e-7 -6.698411569521203e-5 -4.452112451251971e-5 7.325509538369699e-6 0.00017981827287878464 6.283150461678951e-5 -5.3407375238001375e-5 -6.423367274981391e-5 -0.0002060257302975324 -0.00010924320493648202 0.0001072180402087007 3.866341823306492e-5 0.00020267685787695998 3.6804266681623827e-6 -4.58593222748992e-5 -3.402351474794942e-5 -2.5157267448938657e-5 -0.00010921089240871743 -0.00011025230022203549 9.168814896549643e-5 -8.258519490771368e-5; 0.0001949569120137406 -0.00011174398271177383 7.459895248247659e-5 -7.554639917143817e-5 7.544575753341114e-5 4.968306136507825e-5 1.4425329640659041e-5 6.790854210922677e-5 0.00015888662862780965 6.224554788278425e-5 -3.058902104894163e-5 6.441451812353119e-5 0.00023389626555800697 3.4930642572736956e-5 -0.00020938855625498798 -1.2425103091654112e-5 7.64950670367496e-5 -5.29758547832799e-5 6.112100497774746e-5 -5.460394118591695e-6 0.00013458535220198573 -9.844308108369653e-5 0.0002751584249148547 -8.654396544953372e-5 2.207218828917227e-5 -0.00014947910706323067 1.4718589840779645e-5 -4.724166894957859e-5 -4.8542300942780946e-5 -3.745910391952567e-5 5.568715612243908e-5 7.0342670073307e-5; 0.00010648823338003222 -7.626965559877694e-5 -6.675461661627017e-5 0.00018577615521524343 0.00011445595994037608 -6.526556278670918e-5 -0.000237968456692513 -8.547719806058975e-5 5.486076663389597e-5 2.06483990247495e-5 -0.00010815316379860251 0.0001690162342335223 -1.3288442621461544e-5 -0.00018880700843267829 0.00010715831269650968 -0.00017517356643777972 4.903588235053533e-5 -8.545881899165628e-5 -0.00030721506953097126 4.868610160219562e-5 2.5110654699045877e-5 -1.1797324238895992e-5 0.00014873892616621432 4.796314790172011e-5 -4.563830477606962e-6 -1.0601012197481461e-5 3.7549077044491097e-5 -0.00010364369989230829 -3.630059770630617e-6 -0.00013439834326830466 1.87618687256666e-5 2.10740498211351e-5; -1.0150453887419053e-5 -0.00015104013989587255 0.00015956352028621585 
2.4127706568441655e-5 -1.5744323731820788e-6 -0.00014399918116070708 2.4347294969243996e-5 -4.535999847958851e-6 0.0001482828319454475 0.0001254364269003215 0.00011074773822071991 4.798357433499286e-5 -0.0001719214688604737 -6.945118918922204e-5 0.00014734881726651979 -7.227467547310976e-5 0.00014774390176496943 0.00020048244361858364 0.00010574196669364305 -8.180551055961636e-5 0.00016366072843409667 8.460382552908178e-5 -7.653370091831682e-5 -1.695753969800836e-5 -1.792443079141257e-5 2.4378938108906707e-5 -3.007346846449367e-6 6.682681298440647e-5 -4.8476607913854674e-5 -0.00012606433854345613 7.217747217834664e-5 -0.00011875378375081732; 9.773823922620954e-5 9.142728459444269e-5 0.00010313408211694546 -0.00024314408573665564 -0.00012778693552934663 -5.2881635831392316e-5 -0.0001360317451075631 0.000174021154186966 -3.00410024396335e-5 -7.178311852151549e-5 -1.9517756231145962e-6 -3.9758518589615316e-5 -3.317898289645342e-5 -7.21046212606156e-5 1.3899541051489146e-7 -0.00019222137552099463 3.5591906005328715e-5 1.3648002619168147e-6 -6.723204341361433e-5 8.1180589866206e-5 0.00017924633949184556 1.1246953824954834e-5 -0.00010614938718611656 -7.37509376380463e-5 -0.00013432105105714656 -5.398998245475433e-5 5.038507743627325e-5 -0.00018913884331815561 -8.08876843552119e-5 0.0001132272177597481 0.00015471143206703344 -0.00010349997818814852; -3.9810846507760515e-5 -6.843202267817692e-5 -0.00021205123325388088 9.080226257643417e-6 -0.00018737491031685878 8.24491683971348e-5 -0.00012325984120878922 -9.144460497635405e-6 4.6847565820929695e-5 -5.475984348494226e-5 0.00015116099107500641 1.2166454646932586e-5 -8.249000208605768e-6 0.00016831195112273356 -0.00019183539247650394 -2.7413275026345563e-5 0.00020479786631141139 -1.4184375365301474e-5 -0.00018088133642493774 3.636693450346159e-5 -7.012524881146076e-6 6.165355748919031e-5 3.7599299824352056e-5 6.944104402747768e-5 -2.261299837017486e-5 0.00012401624130972195 -0.00016369559729981667 -0.0001154398529678322 
0.00012676176570777141 2.023911841307618e-6 -3.2963020791507834e-5 9.782240827290596e-5; 6.885411194123895e-5 -7.260266947962089e-5 -3.5748151280383823e-6 -3.834286990279482e-5 -1.1513592884962178e-6 4.7003225244755916e-5 -2.1770439434858137e-5 -0.0001038539624463893 0.00022243899065831904 0.00015504937682597468 -0.00015220571303909414 6.903007215062285e-6 1.2143730327428244e-5 -0.0001003859645610798 -0.00019071112847124526 0.00014079298189206637 1.5451815397302943e-5 5.5490000431366645e-5 -3.777289321117692e-5 -1.2110370608244224e-5 3.681238076713569e-5 3.762287515161484e-5 -2.2803952834161823e-5 0.0001586148870952209 -4.453449155373148e-5 7.27935336722831e-5 7.396112841850328e-5 -2.0519580448696464e-5 4.983358367060885e-5 -3.267214133340213e-5 -1.448416542227623e-5 0.00015998879161960162; -5.3547126077391144e-6 -3.11883084355972e-5 -6.143450425551642e-5 -1.351082462460691e-5 2.1090545255300495e-6 0.000207585020893376 -0.0001506306723746138 6.570472782542898e-6 8.6680425942538e-5 5.554175663194715e-6 -0.00012611092804545756 0.00011600844320527545 2.2435024424695262e-5 -0.00010391338562440875 -3.993414045508765e-5 -2.31940482768625e-5 -9.420166657590226e-6 -2.9530539700870857e-5 -0.0001707966291865808 -6.750261652594136e-5 0.0001166426738785878 -0.00014286814237092163 -0.0003244518591378522 1.2221307974327657e-5 9.399285065628213e-5 9.170393445307716e-6 -1.5350103945458937e-5 -2.4665834440016303e-5 0.00011386691060069281 -3.816068399429161e-5 3.0121588901241796e-6 9.27104267647422e-6; 8.272200845826421e-5 -4.6102997500521615e-5 -8.273930673307986e-5 3.975533961497772e-5 0.00016296720620891333 -7.748893933890447e-5 -0.00015439136571071888 -4.2721452543586677e-5 7.252540142221627e-5 3.368797676816636e-5 -2.3171332784069065e-5 -0.0001261202831548452 -2.9484191897749475e-5 8.881523005843118e-5 7.581321527000555e-6 5.969226747115132e-5 -9.249040728858793e-6 -5.0920747368903994e-5 5.976546774814186e-6 -9.455346101016095e-5 -1.2322980935236391e-5 7.902780832669911e-5 
1.3854640358518174e-5 -0.00017759926943394642 -4.385014274444268e-5 0.00011858305713108045 1.4421389253447354e-5 0.00015209871541888142 -0.0001317139229533076 -0.00013680796639815027 8.259388612064024e-5 -0.00010234134772613699; 9.854897869158737e-5 2.5519433599443748e-5 6.632640359097301e-5 -4.5886122490253604e-6 2.8069827728191478e-5 4.129791702948807e-5 1.0126358553722805e-5 -1.1999244655407947e-5 -0.000318090311359951 3.2277341212510506e-5 3.8488789847947966e-5 0.00012130326283736358 -3.209284115941913e-5 -0.00011378666751911592 0.00017943857890427383 0.00015649965278212205 -4.2901334402950075e-5 -3.6821154956806e-5 3.0026241781193794e-5 -5.837465507572739e-5 9.111586039293826e-5 1.666991008899575e-5 0.00011616324989231603 7.652357815171785e-5 3.6278495805933785e-5 -0.0001103636496037817 0.00010652697162808644 4.0328228330363024e-5 -1.4380874766311872e-5 -2.51017372858309e-5 -0.000173255913999271 9.088644944936324e-5; -0.00017471642290217118 1.4791780318446761e-5 -4.700154517510107e-6 4.847904479506637e-5 0.0001461470069040805 -0.00010627520036393267 -5.272946013461649e-5 -1.4210222201988406e-5 7.769080316906273e-5 -0.0001942113763510541 -0.00013066044134062312 4.310038962679458e-6 -0.00013215476296351685 -0.00020616657506649757 1.29639706352758e-5 -0.00018008068086088929 5.757227104880554e-5 2.359538216872476e-6 6.96987458065072e-5 -2.4294581220892844e-5 3.3507015774041405e-5 1.7980067655723635e-5 -9.289187386792271e-6 -7.924251489780304e-5 8.721127001593995e-5 4.7962255353596e-5 -6.49052404190643e-5 -4.557600054281287e-5 -1.977585138262508e-5 -1.662718072508875e-5 3.089826846521289e-5 -4.564993518612753e-5; 3.3443484017276527e-5 -1.1658611992665803e-5 -3.241764670444385e-6 0.00011058721024740749 7.637495905357887e-5 0.00012147383999975926 0.00021944754148565666 0.0001287459486323646 3.371087545959665e-5 -3.796758796821803e-5 3.653247363248295e-5 1.058120828960124e-5 -7.670073909742655e-6 -7.801219585386515e-5 -0.0002148817261055805 -4.359117924619892e-5 
5.190342608635943e-6 -0.00016288943556870756 -7.175012838130699e-5 -8.845435918344398e-5 -0.00013287771112087594 -0.00017725956827204696 -0.0001085131193118862 7.281697940046175e-5 6.520811762160118e-5 7.339792823616811e-5 8.936621047754212e-5 8.723452951956757e-5 -2.8804143266519197e-5 -2.8753135165664703e-5 9.94421675844335e-5 7.85972329598159e-5; 0.00020800731771433503 3.197174054494237e-5 4.05441828564912e-5 -4.2391136512930606e-5 9.059790362639168e-5 0.00011556471917120683 -3.059804613312393e-5 -9.494098004378546e-5 0.00010507673913653066 -9.74252248039132e-5 4.838288522842943e-6 0.000235190018874535 2.0039466552961726e-5 3.7505724766933305e-5 -0.00010029041687684239 -2.7470695850380697e-5 5.310581081121976e-5 9.420421019004486e-5 -1.3870303983472814e-5 -5.171863404997312e-5 8.489126093028005e-5 2.3345083254184943e-5 -7.186063113348608e-5 -6.713475623301291e-6 6.725501633755217e-5 4.4470100410218415e-5 2.2187878033946354e-6 -6.329252893066023e-5 0.00014195084849180436 -9.899444513447649e-5 0.00014181896448407864 -1.4101431143526717e-5; -4.70490544623081e-5 8.349461316553034e-7 -0.0001599831179711938 0.00011944315842157806 0.00012548472799861455 6.566765436475836e-5 4.8312334495515726e-5 -0.00011210436283106319 -3.5074774548028285e-5 -0.00011595835026780606 2.586686199976724e-5 0.00013493110376914597 -0.00016439910037404203 9.783640193093277e-5 -0.0001042248936366133 6.489113506434136e-5 -3.056761712029923e-5 5.417629430924202e-5 5.252394980089759e-5 -0.0001203806118221148 6.290826199149416e-5 8.306669546327976e-5 -4.778958061408721e-6 0.0002374328010253645 2.4966369476553978e-5 -5.197524884095263e-5 6.333255945975958e-5 -2.324351631358329e-5 0.00015309038057496408 7.277609760679035e-5 9.507366989311585e-5 -5.102259315860731e-5; -0.00010602666793382094 -0.00024266906461668976 -6.998006778991656e-5 0.00021218182314663 9.107250457642781e-5 -0.00014486203349186534 0.00022440060070259493 -7.13633655755815e-5 4.225895511285245e-6 3.821040968652616e-5 
0.00011042176425000843 -0.0001062318135587529 -1.597200229987908e-5 -0.00016732231801832486 0.00013205539812808468 -4.900780593362653e-6 2.7239329713137318e-5 -0.00021970867441958156 -2.909255658556733e-5 2.3023870883165228e-5 -0.0001407129041104616 8.86363175163243e-5 -9.805552167343407e-5 -8.894021652628614e-5 8.061101630339416e-5 -0.0001257421120022407 0.0001785303591099594 3.4943309423694664e-6 3.415818150862747e-5 -1.442349753013532e-5 -0.00012142530090165718 -0.00011473060498936516; -5.636634806619663e-5 0.0001213180520480318 0.00010277138870691462 5.7778329288286445e-5 -4.869102405252926e-5 -0.00019520831766754027 7.291349255496773e-5 -2.9676725820583834e-5 -5.324200093496555e-5 -2.2989596473535212e-6 -3.759217644442767e-5 0.0001827814638872474 0.00011677570081333051 5.872285772597762e-5 0.00020280137767799542 -0.00020412264631345558 0.00015619130216323152 8.074697769729738e-5 4.381312909202996e-5 4.1668344134386164e-5 4.742882434537105e-6 5.892554043924465e-5 -2.372008303978115e-5 1.2542768568837475e-5 -0.00014061335888441526 -2.285614176848954e-6 1.2383673389144247e-5 -2.783141927835421e-5 9.316825970558748e-5 -5.503520434472351e-5 -0.00010151199886791553 -9.856107418641093e-5; -1.942899747622799e-5 -7.768120341224733e-6 -2.830238067065975e-5 3.6488503027489536e-5 -2.297225059111201e-5 -7.506610803934038e-5 -5.718132401047207e-5 -7.470592630956488e-5 4.768725680956817e-5 6.431196139497615e-5 7.106113967789535e-5 -7.410925412945589e-5 2.656304000889208e-5 -4.95043992861101e-5 -8.67620350722422e-6 -0.0001135738153981433 -0.00013142081537911827 5.470952724759304e-5 9.941193762539958e-5 -2.5110308925941226e-5 -6.584899041731111e-5 0.0002174294689747274 -4.781613337503973e-5 7.418993241923667e-5 -0.00011194176717357628 -3.3087874399952675e-5 7.420834786795327e-5 -0.00013186651143872407 0.00010778015472065517 -1.0402653643854274e-5 -0.00012015582127891047 -9.224166809976316e-5; -8.67443728819993e-5 5.7081143833882357e-5 0.00017299096486759905 
1.1992977644797364e-5 -7.334304121092297e-5 6.293832368195288e-6 -0.0002715049255383684 -0.00020976841063010917 2.5486602109693446e-5 -0.0001351290326309887 2.027172683675259e-5 1.2769053111791989e-5 -7.639932903466269e-5 -2.1964201212445985e-5 0.00018631146792550128 8.379742336920063e-5 -8.775158369453094e-5 9.403654414294885e-6 -0.00012104228392480123 0.00018991104422832702 -2.2521144844991832e-5 -9.70154347923003e-5 4.431058525463991e-6 -5.3391445024213556e-5 0.00015518793349847968 -1.9283856661688298e-5 -0.00011526601759023036 -0.00018084109490508688 -8.755505280341406e-5 -5.523436505236192e-5 0.00016565968087171388 0.00018285883686267314; 1.3120560406520929e-5 5.698476045826816e-5 0.00010871873267194422 -2.1199709654064318e-5 0.0001458165558959095 0.0002936207808760722 9.106934769979465e-5 0.0001104864629541301 -2.5014852572080742e-5 7.212169178265385e-5 -4.558708430532087e-5 0.00015375925316212243 9.718806425545147e-5 -6.26939229460644e-5 -6.504690215487287e-5 -1.6191047752810896e-5 -5.9824765476234464e-5 -4.857554565213995e-5 9.410024970735802e-5 2.078145983416907e-5 9.14218169145074e-5 -6.760926887195623e-5 0.00016664380994362596 0.000128787875591911 -8.380447316176959e-6 -7.050670071303495e-5 -7.017923168869569e-5 0.00023534795648618117 8.269342952105355e-6 -0.00011805843066221294 -3.0456997838076656e-5 0.0002738268292249905; -3.806988141873401e-5 0.0001307814915999656 -9.942654040782371e-5 -2.6033584879719085e-5 -6.870246332975821e-5 -6.950897412748636e-5 6.603470633767394e-7 1.1176452710104216e-5 7.220557991972942e-5 -8.454941170198333e-5 5.725489438238346e-6 2.0702120957183474e-5 -0.0001719410259174331 1.583862714328263e-5 3.0482537760733842e-5 7.216072364103483e-5 2.461656068776233e-5 -4.675345670689403e-5 -0.00013234557017125588 0.00012953095821291693 -5.12729323308596e-5 7.00482788676133e-6 -9.775921834075454e-5 -8.589950201707594e-5 -0.00013621850420163304 0.00022585893267391976 0.0001115282600149405 -0.00014347263394296834 -3.400587536772558e-6 
0.0001467403296483532 -0.00011278666105494807 -3.427410528088539e-5; 8.558317290037662e-5 -1.7098815309574114e-5 -5.841594452619483e-5 7.929291836801924e-5 -0.00011233767350319116 0.00010960204030737045 2.3558319067823164e-5 -9.686638282262008e-5 9.175133686397119e-5 -3.846885943318905e-5 0.00020403518078267467 -0.0001195650713762662 -7.578001009648898e-5 -5.1420329558888865e-5 -1.2940486389263474e-5 -6.698260453115029e-6 -1.908779472985054e-5 3.295386499842413e-5 -7.199904417035963e-5 -0.00016200130835756227 -6.735079864854222e-5 -0.0001236787231883381 -0.00018851603295304113 -3.173342543420326e-5 -0.0001229870124498769 0.0001359587602072943 0.00015288253575447853 2.5185865293615965e-5 -2.431230337067251e-5 0.00016896652027383215 -1.7740345587362895e-5 0.000177002830010613; 2.7498809424731322e-5 0.0001914993706236483 8.932173724359422e-5 3.639247063333977e-6 0.00011158338812717945 2.8170780490189568e-5 -5.155069383137062e-5 -7.425657466911756e-5 4.477172922060456e-7 0.00011901866738525137 0.00014610470017258827 2.441538955507745e-5 9.039345669633682e-5 -0.0002340310523209258 4.964461978714882e-5 -6.265972414213282e-5 -0.00018630467028516392 -0.00010112211683691526 7.870332111677593e-5 -9.81494201842553e-7 -5.96371603819638e-5 1.9174955662021445e-5 5.560017663127721e-5 -1.4339729466442194e-5 7.437577070276114e-5 0.000254894949387096 3.3935343446943756e-5 2.135571295410482e-5 -0.00011396517883214223 -7.29663363813518e-5 -0.0001813122137605333 -9.879613143085527e-5; -0.00011729975136801955 -1.9606463469573234e-5 -0.00011695924382760573 -0.00014271945161849358 2.7186711518628006e-5 -5.05317225946548e-5 2.101141893548211e-5 8.10052874254577e-5 -9.26540020048926e-5 -5.6068729977022554e-5 0.00013950632950185512 -5.066539284996122e-5 -0.00021210215396820213 7.172105993619476e-6 -0.00010614546823887701 -7.510522897461846e-5 -5.609669512011534e-5 -0.00020618466129633105 -5.514830776822953e-6 0.00019275359175999147 -3.8732036559279856e-5 2.3749982061126284e-5 
4.248766286423865e-6 -1.2737827605035376e-5 -8.034203398380141e-5 -6.021567657113426e-5 3.414256883446912e-5 -0.00013460188398707357 -0.00015725544705162196 9.530354821947732e-5 -0.0001145446227059006 -2.976054980933794e-5; 0.00010490547903200934 2.5142995594345695e-5 1.58051716544835e-5 -0.00014304478644307754 -5.1771678033138495e-5 -6.024414868706469e-5 5.062988305421135e-6 -0.0001298753613690626 9.751408835027512e-5 -5.2113287881088373e-5 0.00022345836806662525 6.492433932463031e-5 -8.517521017610666e-5 -4.2747533358937286e-5 5.240030260493403e-5 0.00010253730034774502 0.0001636019023982096 3.7884294227393e-5 -1.4024813924062077e-5 -8.465146491917655e-5 -8.694775714220124e-5 0.00014793480114630856 8.225708382366212e-6 -5.989890086029796e-5 2.6615152831349744e-5 7.631566587688559e-5 -2.8849497527423046e-5 -0.0001806113705123601 9.199161469323885e-5 6.280478738800058e-5 -5.549397791300879e-5 -3.6575840903368896e-6; 8.974595026276431e-5 -4.864165761343605e-5 -0.00020224177946169916 -7.520099110502651e-5 -0.00015423003436823904 -3.768820546637248e-5 -6.679330193972062e-5 8.363385731822495e-6 -1.760848148503178e-5 -0.00012158301628031171 -5.472049464539934e-5 -6.41362677381805e-5 -0.00019255513701890508 -0.0001199748550245992 -4.3182688514467906e-5 8.881964082294425e-5 -0.0001227905342059648 -0.00013052029329654503 -0.00012665565385781445 9.956855124964923e-5 2.308154385874647e-5 -0.00021929010485130035 5.3131323383888505e-5 -6.47724992998061e-5 1.288577992429456e-5 -5.723180232987191e-5 0.00012990121818777967 -9.160973732151317e-6 -0.0001547657922313049 -5.026727389148376e-5 -1.1773872884315656e-5 6.327739161835124e-5; 0.00016429128353616916 -6.639506264922801e-5 0.00013094565710184718 6.515426556711893e-5 -0.00010935029368717992 6.11261426368063e-5 0.00016906214349570215 -1.3629409668044796e-5 -0.0001178898741361468 5.637197400811485e-6 -4.5226231495455547e-5 -3.3449720465234695e-6 0.00013018117223533653 0.0001444427476510669 -2.4648277221547196e-5 
-0.00011330425278150688 5.97838812520986e-5 6.277999690633994e-5 6.564414124136168e-5 0.00013470700704623103 -3.010620949656866e-5 -6.003102626869625e-5 9.126620042503684e-5 -7.764470328538128e-6 4.0216139123239346e-5 -0.00021215059624386616 5.64888801729984e-5 4.4030485605707465e-7 4.021944240799875e-5 0.00018086831604729266 -9.956644445059788e-5 -0.00019128227669146987; 5.878713373816193e-5 6.810362275130504e-5 -6.950281648757811e-5 -3.795055353409436e-5 -0.00014261353034388327 0.00023216650616516958 0.00022871406427723797 3.3062545397780936e-5 -4.211225940823182e-5 -0.00012407714065490707 4.708926381566005e-5 9.356892870412761e-5 -0.00012596563000558285 8.426472514541474e-5 2.304735345090394e-5 0.00017001360275456314 8.746045849631587e-6 0.00010187457256391278 -5.2610869616978436e-5 -2.5071060819100862e-5 -4.174807226378816e-5 0.0001077904499731837 3.824113547010017e-5 2.2978879413791848e-5 -4.210775195248919e-5 -0.00010837957496848503 -4.928421084629955e-5 -4.071897537075814e-5 -3.877895768825342e-5 0.00011554456896986476 -0.00019508792583225977 -6.144379121112045e-5; 9.477298571040465e-5 -7.597345040592738e-6 8.088526362961309e-6 -6.241539105062994e-5 0.000295641438958837 -5.289105703228629e-5 -9.943991627766008e-5 5.2521150598875703e-5 -8.958965929521347e-6 -1.8173563341734433e-5 -8.157863359188718e-5 0.00011224852622730265 6.530675435682295e-5 0.00012861650592425382 1.6692474480476365e-5 -9.806388717366792e-5 1.1909017109827216e-5 -0.00013382968628535463 -0.00010908270652394002 9.655438751693678e-5 -4.528001246613901e-5 -0.00011652398117107123 -7.991626645194848e-5 -3.68835864831925e-5 2.986070811643089e-5 -3.747601314203893e-5 -1.2009214708411711e-5 -0.00014456127265689449 2.4632010343010323e-5 -2.890161905485211e-5 -0.00018988667599147008 -2.483074660960723e-6; -2.211513370456331e-5 7.262827906926494e-5 -6.699043538717169e-5 -0.0001478895088966746 1.7588671964460432e-5 -0.00020801246378868187 -5.8760072222853295e-5 -0.00014305968457664566 
-1.6123919211151977e-5 4.023772757319682e-5 6.921148209768722e-5 -1.3655441451181415e-5 -3.967569742416309e-5 6.137292561289591e-5 0.0002978554306009852 -4.153586147184675e-5 -7.023587628099332e-5 1.644294692779224e-5 -5.617896263899487e-5 4.281093910508301e-5 8.364634171995076e-5 -0.00011767104538918025 -0.0001339805845946011 -0.00014323659221006257 -8.281100841303172e-5 -5.716432375082405e-5 -0.00012373331506567556 2.194792736375026e-5 -8.291639338311904e-5 -5.222688074171833e-5 0.00020732943570876312 4.836450292897126e-5], bias = [1.9472372754545018e-9, 1.7773607526937544e-9, 5.221018176775004e-9, 3.812312706044044e-11, 3.3611962974319504e-9, -2.2529796583114973e-9, 2.3664357104146897e-9, -3.075449450966708e-9, -1.0155844291358862e-9, 1.9214890777728412e-9, -2.502725337145863e-9, -1.3804557553693939e-9, 2.040885056632904e-9, -3.727604388007969e-9, 6.28961367247158e-10, 3.6675277534707073e-9, 2.4424869728410274e-9, -2.7814599729290868e-9, 1.448721330582054e-9, -1.9014270440556866e-9, -1.868673185906873e-9, 6.3216520272797625e-9, -1.6637277007961234e-9, -5.695747607803651e-10, 7.209678085459496e-10, -5.632104488966317e-9, 1.0059130601703199e-9, -6.741302320635246e-9, 2.183324125291099e-9, 1.0014525581892368e-9, -1.8773902069878824e-9, -3.0073879224747992e-9]), layer_4 = (weight = [-0.00067686551860744 -0.0005368300163628079 -0.0007921738055296399 -0.0006196085037892745 -0.0007244398801363821 -0.0007097497794698365 -0.0006593558512365231 -0.0007416371393926432 -0.0005960887922108718 -0.0007350308875277291 -0.0006156005861764245 -0.0008386894049390347 -0.0005840541633242304 -0.0006029975929283128 -0.0007161310384775486 -0.0005563640575472237 -0.0007283228728292673 -0.0006592820589910498 -0.000854577167622903 -0.0005705475081132089 -0.0006753489018345013 -0.0006530905441083773 -0.0008378781176315651 -0.0007381676075899766 -0.0007651834199543502 -0.0006996916866415245 -0.0006133492045866389 -0.0007554696715538974 -0.0007395449444672793 -0.0008034701429858383 
-0.0006324967714830057 -0.000749918086726935; 0.00026837620332870296 0.00011926672040359025 0.00017999991481591564 6.113838624158019e-5 0.0001654572038887778 0.00031892068605082546 0.00011758285876725387 0.00028264574502074116 8.927743303739393e-5 0.0001664788475610874 0.00033937124998177973 0.00030692119664602264 0.00017449657272732442 0.000299569662985531 0.0002730208905574049 0.0002814352559254577 9.42818779895119e-5 0.00012903017507096194 0.00025098623892860344 0.00013478875062514597 0.00027043268553162845 0.00029922924429021857 0.0004410108320833369 0.00023056252238368554 0.00020238835976350335 0.0003760552184170243 0.00016622745148088641 0.0003166273455014743 0.00028710563913283537 0.00034561011783776824 0.0002618493239492187 0.00020548933982249845], bias = [-0.0007092776601112335, 0.0002397118501774389]))
Visualizing the Results
Let us now plot the loss over time
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    # One recorded loss per callback invocation; x is the 1-based record index.
    iters = 1:length(losses)
    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, iters, losses; marker=:circle, markersize=12, strokewidth=2)
    fig
end
Finally, let us visualize the results
# Re-integrate the neural ODE with the optimized parameters `res.u`, using the same
# fixed-step RK4 settings as before.
# This rebinds the global `prob_nn`. `loss` is unaffected because it always passes
# `p=θ` explicitly.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # Overlay the data with the network's prediction before and after training.
    # Each series is a line plus markers, and the pair forms one legend entry.
    series = [
        waveform => "Waveform Data",
        waveform_nn => "Waveform Neural Net (Untrained)",
        waveform_nn_trained => "Waveform Neural Net",
    ]
    entries = map(series) do (w, _)
        ln = lines!(ax, tsteps, w; linewidth=2, alpha=0.75)
        sc = scatter!(
            ax, tsteps, w; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
        )
        [ln, sc]
    end
    axislegend(ax, entries, last.(series); position=:lb)
    fig
end
Appendix
# Print the Julia version and system information used to build this page.
using InteractiveUtils
InteractiveUtils.versioninfo()
# GPU backend details are printed only when the relevant packages are loaded.
# The `@isdefined` checks ensure that names such as `CUDA` or `AMDGPU` are never
# evaluated when those packages are absent.
if @isdefined(MLDataDevices)
# CUDA: only if the package is loaded and MLDataDevices reports the device usable.
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
# AMDGPU: same pattern as above.
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.