Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
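We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use the Newtonian formulas $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, with $M = m_1 + m_2$, to get $r_1$ and $r_2$.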
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of the two bodies.

Returns the tuple `(r₁, r₂)` with `r₁ = (m₂ / M) .* path` and `r₂ = (-m₁ / M) .* path`,
where `M = m₁ + m₂`, so that `r₁ - r₂` reproduces `path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    body₁ = (m₂ / total_mass) .* path
    body₂ = (-m₁ / total_mass) .* path
    return body₁, body₂
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix into Cartesian orbit coordinates.

`soln` stores one state per column. With two rows, the rows are `χ` and `ϕ`, and the
constants `(p, M, e)` are read from `model_params`, which must then have length 3.
With four rows, the rows are `χ`, `ϕ`, `p`, `e` and `model_params` is not used.

Returns a matrix with two rows, `x = r cos(ϕ)` and `y = r sin(ϕ)`, where
`r = p / (1 + e cos(χ))`.

Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if the two-row form
is used without a length-3 `model_params`.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # `nothing` has no `length`; test for it explicitly so the caller gets the
        # descriptive assertion rather than a MethodError.
        has_params = model_params !== nothing && length(model_params) == 3
        @assert has_params "model_params must have length 3 when size(soln,1) == 2"
        p, _, e = model_params  # M does not enter the change of variables
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# The next function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled vector `v` (spacing `dt`).

Interior points use central differences; the two endpoints use second-order one-sided
stencils. The result has the same length as `v`, which needs at least 3 entries.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, interior, right_edge) / dt
end
# The next function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled vector `v` (spacing `dt`).

Interior points use the three-point central stencil; the two endpoints use four-point
one-sided stencils. The result has the same length as `v`, which needs at least 4 entries.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along an orbit.

`orbit` has the `x` samples in row 1 and the `y` samples in row 2. `component` selects
the entry: `(1, 1)` gives `x² - (x² + y²)/3`, `(2, 2)` gives `y² - (x² + y²)/3`, and any
other value gives the off-diagonal term `x * y`. The result is scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    xx = xs .^ 2
    yy = ys .^ 2
    third_trace = (xx .+ yy) ./ 3
    entry = if component[1] == 1 && component[2] == 1
        xx .- third_trace
    elseif component[1] == 2 && component[2] == 2
        yy .- third_trace
    else
        xs .* ys
    end
    return mass .* entry
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the selected moment-tensor `component`
of `orbit`, sampled with spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the `(1, 1)`, `(1, 2)` and `(2, 2)` quadrupole components for `orbit` and
return them as the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    # Small local helper so the three component calls read uniformly.
    term(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    h11 = term((1, 1))
    h22 = term((2, 2))
    h12 = term((1, 2))
    return h11, h12, h22
end
"""
    h_22_strain_one_body(dt, orbit)

Compute the two strain polarisations for a single orbiting body.

Returns `(s * (h11 - h22), -s * 2h12)` with `s = √(π / 5)` evaluated in the
type of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    prefactor = √(T(π) / 5)
    plus = h11 - h22
    cross = T(2) * h12
    return prefactor * plus, -prefactor * cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components of two bodies and return `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q₁ = h_22_quadrupole(dt, orbit1, mass1)
    q₂ = h_22_quadrupole(dt, orbit2, mass2)
    # Both tuples are ordered (h11, h12, h22); add them entry by entry.
    return q₁[1] + q₂[1], q₁[2] + q₂[2], q₁[3] + q₂[3]
end
"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Compute the (2,2) mode strain from the orbits of BH1 (mass `mass1`) and BH2
(mass `mass2`).

The masses must sum to one. The check uses a tolerance of `1e-12`, widened to a few
machine epsilons of the mass type so that lower-precision masses such as `Float32`
are not rejected for ordinary rounding.

Returns `(s * (h11 - h22), -s * 2h12)` with `s = √(π / 5)` evaluated in the
type of `dt`. Throws an `AssertionError` if the masses do not sum to unity.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    total_mass = mass1 + mass2
    # For Float64 this is exactly the original 1e-12; for coarser types it scales up.
    tol = max(1.0e-12, 8 * eps(float(one(total_mass))))
    @assert abs(total_mass - 1) < tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Turn an ODE solution into a pair of strain waveforms.

`soln` is converted to an orbit with `soln2orbit`. For `mass_ratio == 0` the one-body
strain is returned; for `0 < mass_ratio ≤ 1` the orbit is split into two bodies with
masses `m₁ = mass_ratio / (1 + mass_ratio)` and `m₂ = 1 / (1 + mass_ratio)` and the
two-body strain is returned. Throws an `AssertionError` for `mass_ratio` outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # Test-particle limit: nothing to split.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
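`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses the state

$$u[1] = \chi, \qquad u[2] = \phi$$

and the equations of motion

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p-2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}}$$

where $p$, $M$, and $e$ are constants.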
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE.

`u` holds `(χ, ϕ)`, the second argument unpacks to the constants `(p, M, e)`, and `t`
is unused. Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    c = cos(χ)
    shared_numer = (p - 2 - 2 * e * c) * (1 + e * c)^2
    shared_denom = sqrt((p - 2)^2 - 4 * e^2)
    dχ = shared_numer * sqrt(p - 6 - 2 * e * c) / (M * (p^2) * shared_denom)
    dϕ = shared_numer / (M * (p^(3 / 2)) * shared_denom)
    return [dχ, dϕ]
end
# Problem setup shared by the true simulation and the neural-network model below.
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
datasize = 250 # number of saved time points
tspan = (0.0f0, 6.0f4) # timespace for GW waveform
# NOTE(review): tspan is Float32 while u0 and dt are Float64 — confirm the mix is intended.
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1] # spacing between saved time points
dt = 100.0 # fixed solver step passed to `solve` together with adaptive=false
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic model with fixed-step RK4, saving at `tsteps`, and keep the
# first of the two waveforms returned by `compute_waveform`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); adaptive=false, dt=dt, saveat=tsteps))
waveform, _ = compute_waveform(dt_data, soln, mass_ratio, ode_model_params)
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # Draw the same data as a line and as markers; both share one legend entry.
    l = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s = scatter!(ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (χ, the first component of the state) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.
It is typically never recommended to use globals but in case you do use them, make sure to mark them as const.
We will deviate from the standard Neural Network initialization and use WeightInitializers.jl:
# Network: a cosine applied to the input, then 1 => 32 => 32 => 2 dense layers. Every
# dense layer shares the same small truncated-normal weight initialiser and zero biases.
const nn = let small_init = truncated_normal(; std=1.0e-4)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight=small_init, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight=small_init, init_bias=zeros32),
        Dense(32 => 2; init_weight=small_init, init_bias=zeros32),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.000121833524; 5.226328f-5; -0.00014548776; -9.4247895f-5; 6.5666005f-5; 6.0171973f-5; -2.302966f-5; -4.473732f-5; 0.00011161461; -0.0001343411; 0.00018319611; -2.065665f-5; -0.0001455036; -0.00013316625; 3.3724489f-6; 0.00011287227; -0.000121913195; -5.791526f-5; 2.5383932f-5; 5.8675174f-5; 6.919063f-5; 1.2302994f-5; 0.00011310558; -0.00022887558; 6.239341f-5; -8.6017055f-5; 8.1246224f-5; -2.816787f-5; -5.3732307f-5; 2.1526877f-5; -4.7608614f-7; -0.00012797407;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[0.00012098461 -3.1793716f-6 -0.00016925832 0.00018047818 4.7639092f-5 -2.3853965f-5 8.690618f-6 4.9900984f-5 1.6080638f-5 3.9745402f-5 0.00018865638 -0.00015309214 -0.0001393886 8.16777f-5 0.00015142024 -0.0002877161 -1.15114735f-5 -1.3887803f-5 6.5486296f-5 0.00017115867 2.547166f-6 4.475723f-5 3.370674f-5 7.3497926f-5 -2.3236535f-5 -0.00011454187 -5.7773028f-5 -2.1686733f-6 -0.000117024974 6.817756f-5 4.8544516f-5 1.3873367f-5; 4.9419403f-5 -8.8853405f-5 0.00017627765 -9.2374585f-6 -3.6458423f-5 -0.00010909792 -8.840931f-6 0.00019315782 -0.00023656964 -0.00013001727 -4.6604386f-5 5.4964716f-5 4.194813f-5 8.9456305f-5 -2.0390464f-5 4.315848f-5 0.00011068002 -0.000120885496 -2.7457032f-5 8.654813f-5 -6.19375f-5 8.9015f-5 0.00016170327 -1.3180111f-5 -6.828717f-5 -5.235474f-5 -0.00013817029 1.997806f-5 9.589974f-5 0.00011909511 -2.9582805f-5 -0.000117821124; 2.5021352f-6 -3.972822f-5 4.3475924f-5 -4.351588f-5 -9.322682f-5 0.00012971596 -3.8466005f-5 -0.0002057799 0.000113866685 -0.000121541765 4.7833455f-6 -1.8694108f-6 6.8112444f-5 -0.00010590416 0.00010198495 -0.00014101506 -5.6211968f-5 0.00017691097 0.000109353874 5.7574416f-5 5.8173522f-5 7.6723045f-5 -9.940791f-6 -0.00024062177 -3.8415536f-7 
-4.7900572f-5 0.0001146954 4.4845667f-5 0.00012490057 -5.0847168f-5 -4.9116938f-5 -3.044836f-5; 0.0002739394 -0.00013988842 1.9670928f-5 -5.8641956f-5 1.4994268f-6 -6.505096f-5 9.619649f-6 -0.0001105538 -8.322841f-6 4.948414f-5 7.1096416f-5 0.00021284555 2.8332284f-5 -0.00019283444 -9.614904f-6 1.9693032f-5 6.488588f-5 2.8474196f-5 -5.396195f-5 5.5872293f-5 7.812282f-5 2.841903f-5 4.291325f-5 0.00016480341 -8.080016f-5 9.2782815f-5 -9.531977f-5 -1.4011667f-6 -0.0001287728 9.0707435f-5 -0.0001104038 -0.0001779106; -0.00023792716 -8.982253f-5 6.70181f-5 -2.4508068f-5 -3.8080998f-5 6.6884066f-5 0.00018756089 -0.00012504817 -6.865935f-5 -4.5984332f-5 -9.6874166f-5 6.2896f-5 0.00012978372 -0.00018000495 0.00014166064 2.0420945f-5 3.8261467f-5 0.00013661456 -1.2250951f-5 0.00017316673 -0.00015296438 2.610997f-5 8.0482976f-5 6.0651666f-5 2.7822087f-5 4.8192137f-6 6.503187f-5 1.2376566f-6 -0.00018387396 -3.2783882f-5 -0.000107452186 6.758489f-5; 4.2901065f-5 4.379356f-5 0.00014984349 -3.418293f-5 2.6760583f-6 -5.6640532f-5 -0.00016193806 -4.385191f-6 0.000117384065 -7.436404f-5 -7.2863f-5 -4.2144537f-5 -5.2375082f-5 2.6375717f-5 -0.00015052734 -6.471826f-5 -0.000112144466 8.368769f-5 -8.4290295f-5 -1.9809591f-5 -0.00024052191 -4.1300707f-5 4.787187f-5 5.423935f-5 -0.00012322601 0.00021362747 -1.43388215f-5 0.00028168765 7.772276f-5 2.7392995f-5 0.00013816627 -2.3387076f-5; 0.00019933254 3.858486f-5 -6.664329f-6 -5.5389974f-5 1.3964216f-5 3.0504829f-5 0.00020237041 -0.000113360155 -2.2153403f-5 9.9686935f-5 0.00011600264 -4.5044013f-5 -8.450763f-6 0.0001412588 7.327027f-5 -2.6356816f-5 -8.1216356f-5 -9.648212f-6 -5.1866165f-5 -4.9239752f-5 -6.4575993f-6 -0.00015997741 0.00014748727 -9.527522f-6 5.1776842f-5 -2.6598744f-5 0.00012829683 0.00018235025 -4.756466f-6 0.00015002387 -9.3117815f-5 -8.551542f-5; -5.176004f-5 3.076354f-5 9.816711f-5 -6.236128f-5 -0.00019677612 -3.37232f-5 -8.745742f-5 -6.0339855f-5 0.00013477103 -3.506862f-5 0.00021175008 -0.000115124036 -0.0001341444 
0.00020790416 -0.00012419281 -9.56815f-6 -6.937415f-5 2.5049725f-5 8.8678615f-5 4.6431513f-5 -0.00015388161 -0.00013225916 0.00011604144 1.6400947f-5 -3.938651f-5 0.00012639961 0.00020744458 -1.4476744f-5 0.00010785135 0.00014434729 -4.3718057f-5 6.7858506f-5; 4.13596f-5 7.4824197f-6 -8.019998f-5 -0.00014600964 -7.5371914f-5 0.00013267744 -3.8512422f-5 0.00011815818 -2.4126415f-5 6.986372f-5 -8.782615f-6 -0.0001640415 -0.00010012986 0.00011750665 -0.00022745525 9.21394f-5 5.2887095f-5 9.866504f-5 -3.2957007f-5 -0.00014822208 0.00017615996 0.0001063579 7.165334f-5 0.00018310906 -0.00016610672 -9.8517565f-5 8.6435495f-5 -0.00018733062 6.0782197f-5 -2.9121129f-5 8.24901f-5 -0.00015860566; -2.6852662f-5 5.0577448f-5 -0.00012037723 -4.8278307f-5 1.1415976f-5 5.430896f-5 2.164788f-6 5.0023056f-5 1.8116867f-5 3.182565f-5 -9.876856f-5 0.00015290065 8.620655f-5 0.00017282672 -0.00018612221 -3.109368f-5 -7.5140146f-5 8.1189435f-5 -6.821774f-5 2.0033973f-5 -6.0568414f-5 -2.2967035f-5 -0.00011055566 -0.00014637492 -6.31712f-5 4.3154156f-5 -5.9654445f-5 -0.000119849225 3.6521596f-5 -3.804658f-5 -9.9991325f-5 0.00012410116; 0.000118556294 0.00018987541 -1.9791749f-5 -1.0431683f-5 -3.544049f-5 -5.1498417f-5 0.00010475172 6.8089816f-5 2.6321764f-5 -8.731531f-6 7.166928f-5 2.068046f-5 -7.157632f-5 7.017133f-5 8.780772f-5 8.946799f-6 -5.1332554f-5 -7.082219f-6 -4.3213808f-5 0.000101384445 -5.585334f-5 2.966609f-5 -5.8511232f-5 2.3084063f-5 -5.5145756f-5 0.00015625282 0.00015639832 0.00012072952 -0.00014091987 0.00010839788 -0.00013098048 -6.617249f-5; 8.349264f-5 -2.0313326f-5 -0.00014538634 0.00013645062 1.6249689f-5 -0.00014140893 -0.00015707125 6.337486f-5 -0.00018151256 3.7975657f-5 -4.1208907f-5 5.7036632f-5 -2.3838567f-5 2.0624186f-5 -5.438394f-5 -4.8406127f-5 9.4805815f-5 2.3522066f-6 4.484111f-5 1.5686823f-5 0.00013309681 -0.000126142 -7.3574706f-6 0.00010592853 0.00022496763 6.0129245f-5 -0.0001467932 6.8567743f-7 4.086242f-5 4.3271404f-5 0.0001122748 4.089129f-5; 
9.054396f-5 7.603145f-5 -9.916828f-5 -0.00014290806 -5.6543636f-6 0.00012970249 5.1426923f-6 -5.9568345f-5 8.3458246f-5 -5.3231397f-5 -5.48199f-5 9.935899f-5 -8.154001f-5 -0.00010688838 -6.875332f-5 -5.3358897f-5 -8.832044f-5 -6.0976337f-5 -6.9511385f-5 0.00021521635 -3.9070665f-6 -0.00012724209 0.00011884495 -1.6630574f-5 1.2037118f-5 1.791825f-5 -0.00015724775 6.93324f-6 -0.00013000787 7.930971f-5 -1.5113076f-5 0.00011165162; 1.41573855f-5 4.692791f-5 9.490657f-5 5.4507735f-5 3.813516f-5 -3.6407793f-5 -3.483532f-5 -0.00023602015 2.0303567f-5 8.5597945f-5 3.3068376f-5 1.3202371f-5 7.893839f-5 -7.8914934f-5 0.00021176679 5.3839995f-5 0.000115054 0.00012146225 -8.834894f-5 9.35062f-5 -5.053624f-5 3.2499431f-6 1.7776372f-5 1.6601733f-5 -9.680719f-5 -4.3834752f-5 0.0001230333 8.876548f-5 -6.380355f-5 4.706044f-5 -6.294361f-5 -1.0442521f-6; 5.0566864f-6 0.00014824164 -0.00014997872 2.4154366f-5 9.578502f-5 4.667641f-5 -0.00016604921 -1.6517923f-5 -0.00015009103 -3.3604316f-5 6.709687f-5 -0.00013248352 2.400415f-5 3.7012716f-5 8.3898994f-5 -0.00012580639 -8.085388f-5 9.458848f-5 1.5928259f-5 1.3944426f-6 -0.00018008807 -0.00018370894 0.0002067821 -4.12879f-5 0.000182462 4.478851f-5 -7.813366f-5 -0.00012109349 -6.543676f-5 -0.00017203121 -2.1334403f-5 -8.763678f-5; 9.989225f-5 0.0001502882 6.314273f-5 -0.00010726669 -2.720436f-5 -2.153743f-5 -0.0001430121 -0.00013020477 2.1699849f-5 8.13572f-5 -8.213741f-5 2.5950294f-5 -0.00015455997 0.00027970673 1.0076456f-6 -0.00010057043 -4.1876778f-5 -0.00019170379 -0.00017699419 -3.927701f-6 8.4509666f-5 -2.7918053f-5 0.00013301744 -3.2339572f-5 -0.00014433631 0.00013277968 2.3637205f-5 -4.8393693f-5 0.00015155907 -0.0001393549 5.192361f-6 0.00014203414; 7.3316674f-5 5.731449f-5 -1.4807985f-5 -8.87956f-5 0.0001649868 -7.339981f-5 -1.7766599f-5 -8.2668514f-5 -4.9548573f-5 8.835331f-5 -5.015f-5 7.413129f-5 2.251642f-5 -1.5465224f-5 5.3731572f-5 -0.0001254326 -6.8018344f-5 -5.345368f-5 -7.823229f-5 -5.721f-5 4.7523212f-5 
-0.000115950505 -2.607217f-5 4.7105277f-5 7.049033f-5 3.1321135f-5 7.010367f-7 6.28337f-5 -0.00015961911 -7.3627794f-5 4.0159324f-5 -0.00014225848; -7.222643f-5 -0.00016379684 -0.00019247628 -1.2092276f-5 -2.4465391f-5 -0.0001474846 -0.00014230137 -0.0001156653 -6.0596172f-5 -1.24984335f-5 -8.264519f-5 3.6915677f-5 1.0720873f-5 -5.038863f-5 -0.000102133235 7.9210666f-5 -5.3792177f-5 -1.1838958f-5 0.000100785284 -0.00018950333 -0.00014778788 -0.0001352631 -4.6875415f-5 -8.655309f-5 -6.539405f-5 9.9534074f-5 4.9282866f-5 -3.9072915f-5 -0.0001795919 4.7024943f-5 7.888468f-5 -0.00014586058; 6.1346465f-5 -6.565402f-5 0.00018427773 -0.00013512117 6.4596614f-5 -6.0913266f-5 0.00011180679 9.808862f-5 -5.4451495f-5 3.5919984f-5 2.9497589f-5 0.0001353019 -8.443911f-5 0.00011229194 -1.0832092f-5 4.263529f-5 1.4503634f-5 2.0462671f-6 -0.00015365852 7.8180565f-6 0.00022848284 -1.0336086f-5 -0.00010773036 -0.00017025125 5.977596f-5 2.7099957f-5 0.00027171097 -2.0272926f-5 -7.850733f-5 4.1249084f-5 -2.260254f-5 0.00013699204; 2.201551f-5 0.000101947764 -1.3533097f-5 4.5069333f-5 7.385084f-5 3.8975722f-6 -0.0002446656 0.00026689176 0.00012953009 6.0405626f-5 6.391655f-5 -2.3103383f-5 5.845412f-6 8.512497f-5 -8.391975f-5 -2.5002013f-5 3.9156566f-5 -3.117346f-5 3.7621438f-5 0.00012283956 1.6659765f-6 0.00021855925 5.763953f-5 0.00019929717 8.494671f-5 -8.401957f-5 -8.47339f-5 0.00017181144 -0.00011680802 7.230197f-6 0.0001716738 3.8542257f-5; -1.37670695f-5 -9.663928f-6 -6.003368f-5 4.6093985f-5 -9.5465155f-5 5.017987f-5 -3.8096998f-5 -0.00010724861 -4.5834135f-5 0.00012484415 2.3732611f-5 6.7478526f-5 -2.9455088f-5 0.00018137283 -7.349032f-5 -7.361027f-5 6.160671f-5 -6.2400427f-6 2.842734f-5 5.6415774f-5 -2.798274f-6 -1.6937187f-5 -8.0031044f-5 -0.00016528771 9.475667f-6 -8.488308f-6 3.9653245f-5 7.937549f-5 -6.944379f-5 1.5236237f-5 0.00011267403 0.00012128895; -0.00017419584 -3.4173652f-5 -0.00012293477 3.4231423f-6 3.0489744f-5 9.652904f-5 0.00019198944 -3.2510125f-5 
4.7263693f-5 -5.7305246f-5 6.433572f-5 -7.056776f-5 0.0001446554 -0.00019137186 8.4854335f-5 -1.7600387f-5 -0.00011257039 3.6661346f-5 -5.9924616f-5 5.2464595f-5 -0.000105812716 8.119278f-5 -6.3595144f-5 -3.383891f-5 9.2830735f-5 0.00017736525 -5.7138375f-5 -0.00016214531 -0.00020733538 -8.4069696f-5 2.861184f-5 4.375445f-5; 2.6326405f-5 -5.3000567f-5 7.486468f-5 -4.5020915f-5 0.00010197314 -7.5301825f-5 -1.217346f-5 0.0001330116 -3.3855064f-5 -6.60505f-5 0.00020880345 2.548983f-5 2.3604564f-6 -2.1228494f-5 -0.00014718987 -7.9213576f-5 0.00021136706 -7.9419326f-5 9.5405514f-5 -6.34242f-5 -0.00017324013 5.9023137f-6 9.2702186f-7 -0.000113373135 6.7464325f-6 -9.6387535f-5 -6.1322426f-5 -0.00016411848 0.00015921654 -4.0238872f-5 6.744505f-5 -1.7572334f-5; 0.00013766704 -3.179989f-5 -8.1811064f-5 4.1163614f-5 1.6567418f-6 3.2143955f-6 -0.0001678506 0.00032219628 0.00015908263 3.277248f-5 5.3267224f-5 1.3909649f-5 -1.06071275f-5 -6.6437206f-5 -1.6663917f-5 5.4900098f-5 -2.6091016f-5 -0.00014859336 -1.388853f-5 -0.00011503227 0.00017579479 8.59001f-5 1.6323997f-5 9.890471f-5 -3.973207f-5 2.104771f-5 -1.9911795f-5 -4.546562f-5 7.140483f-5 4.6305773f-5 -1.6572103f-5 -0.00012274306; -0.00011224545 -2.621561f-5 -0.00017563188 -3.5770667f-5 2.9284243f-5 -7.282488f-5 2.7839824f-5 0.00015262378 -3.5693523f-5 -0.00014342621 -2.3579209f-5 6.417864f-5 -9.841309f-6 -0.00015152291 -8.997487f-6 -4.2796397f-5 3.5923615f-5 0.00011743984 0.00013756075 -1.0368818f-5 6.2818566f-5 1.7841692f-5 -7.8469726f-5 3.5093934f-5 0.00017766547 0.000105008396 7.4033596f-5 1.5816251f-5 0.00021932441 6.444105f-5 0.0001149696 -9.3690775f-5; 9.390687f-5 2.8817527f-5 2.5798822f-6 2.9104258f-5 8.431707f-5 -0.00015002012 4.8036967f-5 -0.000103617014 -3.9822688f-5 4.7222733f-5 2.4699606f-5 -6.951308f-5 5.2539497f-5 -6.128821f-5 4.7356913f-5 -1.6324238f-5 -0.00013462668 1.34276615f-5 -4.3995336f-5 -8.94021f-5 0.000115877694 -6.0422084f-5 -0.00014238966 -2.7973496f-5 6.1959465f-5 -2.2907701f-5 -1.9990286f-5 
0.00013034731 -0.00011006391 -9.603271f-6 -0.00010462243 -1.5545793f-6; 8.062468f-5 -4.2360323f-5 1.7561892f-5 8.393393f-5 0.00015709623 8.183535f-5 -0.00013649077 -0.00012413497 4.602098f-5 -7.764947f-5 -7.094609f-5 1.0658274f-5 -2.959047f-5 -1.7481376f-5 6.0271344f-5 -7.5923315f-5 3.4294942f-5 1.96222f-5 -6.052716f-5 -8.948569f-5 -5.8697893f-5 -6.406193f-5 -2.0929248f-5 1.096763f-5 4.8788574f-5 -8.4707295f-5 5.0446222f-5 0.00012770953 -9.476659f-5 0.0001229799 -4.7511487f-5 -5.9715047f-5; -1.0178862f-5 7.3337855f-5 -2.5639087f-5 0.00018021169 -0.00015388745 -0.00016907648 -1.0999256f-5 -2.7758468f-5 5.8770212f-5 -6.497853f-5 -0.00010794045 -5.6342495f-5 -0.0001440074 0.00019699364 0.00019142384 -2.0811793f-5 -0.00012644348 -0.0001031679 2.6509444f-5 3.5410056f-5 3.473423f-5 0.000105799896 -0.00019377902 0.00010511171 -5.4991862f-5 0.00029245164 -8.931673f-5 0.00010792144 8.4002684f-5 -2.3492994f-5 -0.0001768825 0.00014051606; -3.1999287f-5 0.0001589123 -7.964392f-6 6.990975f-6 -2.5815145f-5 -0.00016294172 -1.820594f-5 -0.00018865314 -9.398425f-6 -3.7620794f-5 -9.5399184f-5 -0.00011138453 -9.280464f-5 -0.00020220577 2.696095f-5 0.00018445808 -7.674632f-5 -0.00017210838 -0.00010171319 6.396511f-5 -0.0002505512 2.9001494f-5 -9.108436f-5 0.00013966288 3.3753502f-5 5.329579f-5 -1.0874329f-6 -0.00017917764 9.5715695f-5 -3.960113f-5 -0.00016766136 -2.348212f-5; -3.275683f-5 7.431519f-5 -0.00018415571 9.620622f-6 -7.77141f-5 -0.00021591026 -7.5927936f-5 -8.623112f-5 -3.1932006f-5 2.9730278f-5 8.130606f-6 -2.2231518f-5 -9.6668744f-5 1.9113155f-5 8.840355f-5 6.422152f-5 8.5797845f-5 0.00011269359 0.00014150617 8.0359176f-5 1.7949633f-5 0.00013993705 -9.272845f-5 -3.4993845f-5 1.0332436f-5 -0.00011463174 0.00012920162 8.827542f-5 -0.00012735795 8.8723085f-5 -2.1776566f-5 -0.00014436342; 3.877176f-5 0.00017321152 2.5268831f-5 9.085999f-5 4.3523795f-7 -7.128913f-5 3.0782066f-5 -8.8166234f-5 0.0001516664 0.00016703816 -5.544814f-5 -0.0001144335 -6.2272855f-5 7.0821116f-5 
0.00016659602 3.247815f-5 7.181978f-5 -0.00011396051 -1.563216f-6 0.000119876735 -5.1921284f-5 -1.1706596f-5 2.6589027f-5 0.000106274565 -0.00011658598 6.409472f-5 -0.00019067027 -7.996467f-5 -0.0001549154 -6.809827f-5 3.483791f-5 -4.262303f-5; 9.1466774f-8 -0.0001699263 4.652093f-5 7.685748f-6 0.000115453564 7.0158226f-6 7.417135f-5 5.286716f-5 7.069997f-5 0.00019550894 -2.6489204f-5 6.037101f-5 -4.768312f-5 1.4106743f-5 -0.000321441 1.50782125f-5 4.609994f-5 -0.00013025381 0.00013802547 7.115038f-5 5.9167273f-6 0.00011504012 -7.0393835f-5 3.920842f-5 -1.8795992f-5 -2.1412621f-5 1.5771775f-5 2.7853648f-5 -4.5204994f-5 -0.00018526666 7.925317f-5 -3.9292413f-6], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-1.9511695f-5 -0.00022823054 -8.655644f-6 -0.00017397311 -2.1114156f-5 -6.878444f-5 6.6787296f-5 -9.506661f-5 3.357921f-5 0.000119168784 -0.00014443895 -8.4686195f-5 -8.264408f-5 -0.00010740179 5.4650267f-5 -6.3918844f-5 1.6851604f-5 3.3557803f-6 -0.00024130766 7.68515f-5 8.421087f-5 -0.00010970868 -0.00016246928 2.176999f-5 0.00016096362 -5.6523724f-5 -0.00012536129 -0.00018377756 7.910392f-5 -9.135885f-5 -3.6920348f-5 -5.6634708f-5; -3.0295498f-5 4.4108157f-5 0.000118441574 3.9251823f-5 -0.00011753446 0.00018128163 2.1919685f-5 0.0001351332 0.00013913927 -7.3695446f-5 0.00017097435 0.00020229058 0.00014949165 -7.5894786f-5 -4.9333215f-5 2.2727245f-5 6.555802f-5 0.0001533251 -3.6385827f-5 -9.691387f-5 -3.4643497f-5 0.00012379723 -1.336216f-5 -0.00010422946 -5.489831f-5 -1.3416415f-5 8.7267865f-5 1.55002f-5 -4.27016f-6 -0.00017338293 -0.0001361506 -0.0001305795], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64
# Convert the parameters `ps` with `f64` (Float64 is needed here) and wrap them in a
# ComponentArray so they can be passed around as the model parameters.
const params = ComponentArray(f64(ps))
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
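# plus 0 states.
Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected by the neural network. It uses the state

$$u[1] = \chi, \qquad u[2] = \phi$$

and the equations of motion

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathrm{NN}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathrm{NN}_2(\chi)\bigr)$$

where $p$, $M$, and $e$ are constants and $\mathrm{NN}_1$, $\mathrm{NN}_2$ are the two outputs of the neural network.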
"""
    ODE_model(u, nn_params, t)

Right-hand side of the Newtonian orbit model with a neural-network correction.

`u` holds `(χ, ϕ)` and `nn_params` are the network parameters; `t` is unused. The
constants `(p, M, e)` come from the global `ode_model_params`. Both rates are the
Newtonian factor `(1 + e cos(χ))^2 / (M p^(3/2))` multiplied by one plus the
corresponding network output, and are returned as the vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example `st` is an empty NamedTuple, so it can safely be ignored; in
    # general `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Roll out the neural-network model with the untrained parameters, using the same
# fixed-step RK4 settings as the true simulation.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=params, adaptive=false, dt=dt, saveat=tsteps)
)
waveform_nn, _ = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # Reference data first, then the untrained prediction.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; alpha=0.5, marker=:circle, markersize=12, strokewidth=2
    )
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural-network ODE with parameters `θ` and return the mean-squared error
between the first predicted waveform and the reference `waveform`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0=u0, p=θ, adaptive=false, dt=dt, saveat=tsteps)
    pred_waveform, _ = compute_waveform(
        dt_data, Array(trajectory), mass_ratio, ode_model_params
    )
    return mseloss(pred_waveform, waveform)
end
# Warmup the loss function
loss(params)
0.00066574477861615
Now let us define a callback function to store the loss over time
# History of loss values, appended to by `callback`.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses`, print the iteration number `θ.iter` together with the
loss, and return `false`.
"""
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This works well here because the Newtonian model already provides a very good initial guess.
# Gradients of the objective are computed with Zygote.
adtype = Optimization.AutoZygote()
# Optimization.jl calls the objective with `(x, p)`; `loss` only needs the first argument.
optf = Optimization.OptimizationFunction(adtype) do θ, _
    return loss(θ)
end
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a backtracking line search, starting from the untrained parameters and
# capped at 1000 iterations; `callback` records and prints the loss as it goes.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00012183352373527359; 5.2263279940221405e-5; -0.00014548776380233388; -9.424789459440731e-5; 6.566600495716713e-5; 6.01719730184056e-5; -2.302965913257806e-5; -4.4737320422327853e-5; 0.000111614608613203; -0.0001343410986009033; 0.00018319611262970862; -2.0656649212462025e-5; -0.00014550359628632983; -0.0001331662497248459; 3.3724488730495716e-6; 0.0001128722724388938; -0.00012191319547115663; -5.791525836683182e-5; 2.538393164286155e-5; 5.8675173931916026e-5; 6.919063162044593e-5; 1.2302994036855851e-5; 0.00011310558329564077; -0.00022887557861376017; 6.239341018953164e-5; -8.601705485494407e-5; 8.124622399912329e-5; -2.8167869459082656e-5; -5.373230669645255e-5; 2.1526877389964857e-5; -4.7608614295287893e-7; -0.00012797406816380163;;], bias = [-1.4611921765804536e-16, 1.405085072960725e-17, -2.0970295297675375e-16, -4.269829923994799e-17, -1.4772124370410614e-18, 3.94295308554932e-17, -1.2638924706180172e-17, -1.1177931263413346e-16, 1.1865509309288578e-16, -1.0780498888415572e-16, 2.774634023014732e-16, -8.3749062006209e-18, 4.5805789170468063e-17, -2.1292553190968523e-16, 1.944950411380796e-18, -6.41029668131954e-17, -1.3723958107570282e-16, -1.712343471384874e-17, 2.4409822423009022e-18, 8.031095994553177e-17, 1.5687737107390968e-16, 1.854854629508834e-17, 5.623402277239797e-17, -2.143245642134635e-16, 1.1981360878871646e-17, 5.745756422297649e-17, 7.289141647501363e-17, -6.320461665192035e-17, -2.3457913380640443e-17, 2.589678808229584e-17, -2.907380498520423e-19, -1.0328245683315882e-16]), layer_3 = (weight = [0.0001209855919779289 -3.178388239484936e-6 -0.00016925733400685806 0.0001804791619843833 4.764007569347019e-5 -2.385298178242054e-5 8.691601736336992e-6 4.990196718890736e-5 1.6081620900055363e-5 3.974638563985942e-5 0.00018865736744654805 -0.00015309116027028663 -0.00013938762169974018 8.167868615882485e-5 0.0001514212282682827 -0.0002877151189994981 -1.151049020964313e-5 
-1.3886819785421171e-5 6.548727940125951e-5 0.0001711596498108516 2.5481494259102896e-6 4.4758214401645764e-5 3.370772280786358e-5 7.34989090897965e-5 -2.3235551295238637e-5 -0.0001145408885594423 -5.7772044626200605e-5 -2.1676899695189575e-6 -0.0001170239909942213 6.817853970742419e-5 4.8545499496959375e-5 1.3874350018427356e-5; 4.941945309753643e-5 -8.885335471943819e-5 0.0001762777022295921 -9.237408158631239e-6 -3.645837267042939e-5 -0.00010909787004577795 -8.840880292089226e-6 0.00019315786568683708 -0.00023656959154407374 -0.00013001722316487586 -4.6604335403506756e-5 5.496476598552565e-5 4.1948179677157754e-5 8.94563557440705e-5 -2.0390413447719565e-5 4.31585315882985e-5 0.00011068006944612713 -0.00012088544524970332 -2.7456981734692627e-5 8.654817731355449e-5 -6.193744837303515e-5 8.901504708689774e-5 0.0001617033188441128 -1.3180060414859987e-5 -6.828711837527211e-5 -5.235468841504181e-5 -0.00013817023926178797 1.9978110975964418e-5 9.589979101208172e-5 0.00011909515826034496 -2.9582754615082903e-5 -0.0001178210738370836; 2.5019708503714492e-6 -3.972838449741239e-5 4.347575967798582e-5 -4.351604430454275e-5 -9.322698743531929e-5 0.00012971579716790016 -3.8466169560398e-5 -0.00020578006361362053 0.00011386652048926395 -0.00012154192946369203 4.783181150142916e-6 -1.869575119290706e-6 6.811227948886917e-5 -0.00010590432494396721 0.00010198478711785275 -0.00014101522238675908 -5.621213195596282e-5 0.0001769108014671832 0.00010935370964247264 5.7574251779751055e-5 5.817335776768179e-5 7.672288056013634e-5 -9.940955085458068e-6 -0.00024062193252362502 -3.8431970907425207e-7 -4.790073645168416e-5 0.0001146952375096039 4.4845502354600465e-5 0.00012490040834776603 -5.084733193840368e-5 -4.9117101941793275e-5 -3.0448524473886115e-5; 0.00027393972275626527 -0.00013988809361740204 1.9671253976636446e-5 -5.864162951161407e-5 1.499752879907309e-6 -6.505063494819414e-5 9.619975098984536e-6 -0.00011055347706110396 -8.322514549720222e-6 4.9484466653891116e-5 
7.109674195839506e-5 0.0002128458786321508 2.8332610309231906e-5 -0.00019283411667834384 -9.614578193350885e-6 1.969335833586923e-5 6.488620281765852e-5 2.8474522405530244e-5 -5.39616244348062e-5 5.587261919595106e-5 7.81231487783303e-5 2.8419356094899844e-5 4.291357479837427e-5 0.00016480373458784285 -8.079983082984775e-5 9.278314108109797e-5 -9.53194408064079e-5 -1.4008407058563153e-6 -0.00012877246765106173 9.070776148314425e-5 -0.00011040347591892975 -0.00017791026904865503; -0.00023792729105457739 -8.982266361249736e-5 6.701797330629033e-5 -2.4508198366450202e-5 -3.808112804086065e-5 6.688393561512111e-5 0.00018756075671285256 -0.00012504830290686468 -6.865947860183279e-5 -4.598446237628408e-5 -9.68742961774766e-5 6.28958668314652e-5 0.0001297835851075868 -0.00018000508051022951 0.00014166050907101345 2.0420814435280966e-5 3.8261336994989977e-5 0.00013661442873876418 -1.2251081061324974e-5 0.00017316659879747654 -0.00015296450790606557 2.6109840219075366e-5 8.04828459151826e-5 6.065153585973279e-5 2.782195671920545e-5 4.819083569825275e-6 6.503173874102424e-5 1.237526517112982e-6 -0.0001838740855234874 -3.278401261900685e-5 -0.00010745231618762154 6.75847631284774e-5; 4.290084752791993e-5 4.379334195472646e-5 0.00014984326908087838 -3.418314672818018e-5 2.6758406860539258e-6 -5.66407500642596e-5 -0.00016193828132746415 -4.385408648529648e-6 0.00011738384700801636 -7.43642588546268e-5 -7.286321424690442e-5 -4.2144754208136775e-5 -5.237529994847708e-5 2.637549978170996e-5 -0.00015052755707220785 -6.471847818975851e-5 -0.00011214468331818969 8.368747576095514e-5 -8.429051313360823e-5 -1.9809808965517427e-5 -0.00024052213057221486 -4.130092502383187e-5 4.787165080269285e-5 5.423913234917215e-5 -0.00012322622869906594 0.0002136272530136854 -1.4339039172780233e-5 0.0002816874336516704 7.772254384544093e-5 2.7392776881599366e-5 0.00013816605560204335 -2.3387293642014497e-5; 0.00019933501108515837 3.858733044671694e-5 -6.661859960346239e-6 -5.5387504734431113e-5 
1.3966684956357092e-5 3.050729766164237e-5 0.0002023728834601589 -0.0001133576856673237 -2.2150933990822887e-5 9.968940367888237e-5 0.00011600510562036385 -4.5041544116427605e-5 -8.448294001042135e-6 0.00014126127144753532 7.327273716419088e-5 -2.635434747299646e-5 -8.121388735981828e-5 -9.645742910834455e-6 -5.186369661623577e-5 -4.9237283282309804e-5 -6.455130450378112e-6 -0.00015997494255225102 0.00014748973854783917 -9.525052963948363e-6 5.1779310989331055e-5 -2.6596274882653194e-5 0.00012829929847702872 0.00018235272228685228 -4.753997008265279e-6 0.00015002634109896658 -9.311534585230898e-5 -8.55129523997783e-5; -5.175915939960097e-5 3.076441716642067e-5 9.816798992344527e-5 -6.236040062846846e-5 -0.00019677524105988538 -3.3722322408511144e-5 -8.745653904385852e-5 -6.0338975894199995e-5 0.00013477191386866984 -3.506774155882679e-5 0.00021175096335011525 -0.00011512315686902016 -0.00013414352143060999 0.00020790503767441476 -0.00012419193226721108 -9.567270902854989e-6 -6.937327416723994e-5 2.5050604014023445e-5 8.867949363432834e-5 4.643239153322691e-5 -0.00015388073516421126 -0.00013225827715702328 0.0001160423184012215 1.6401825993139403e-5 -3.9385629501129415e-5 0.0001264004880760964 0.00020744545908767565 -1.4475864740023351e-5 0.00010785223130026775 0.00014434816548258083 -4.371717805385563e-5 6.785938457708361e-5; 4.1358963598822045e-5 7.481784350299129e-6 -8.020061305644904e-5 -0.00014601027452075158 -7.537254951815454e-5 0.0001326768009582034 -3.8513057561735064e-5 0.0001181575475035663 -2.4127050537873756e-5 6.986308780112255e-5 -8.783250096099978e-6 -0.00016404213475683182 -0.00010013049414767866 0.00011750601460304987 -0.00022745588886239986 9.21387667309943e-5 5.2886459697828436e-5 9.866440576812832e-5 -3.295764200659876e-5 -0.00014822271860824316 0.0001761593206444257 0.00010635726120168857 7.165270416377889e-5 0.00018310842769067465 -0.0001661073571179562 -9.851820014803936e-5 8.643485979488769e-5 -0.00018733125271895373 6.0781561287209514e-5 
-2.9121764428358908e-5 8.248946450263772e-5 -0.00015860629592710258; -2.6854197959942317e-5 5.057591178086869e-5 -0.00012037876959907561 -4.8279843132657654e-5 1.1414439621548042e-5 5.430742302309116e-5 2.163251873241373e-6 5.0021520190455606e-5 1.8115331011762093e-5 3.1824113728768345e-5 -9.87700995315617e-5 0.00015289911888288502 8.620501383888797e-5 0.00017282518737233335 -0.000186123747800196 -3.109521722904562e-5 -7.514168183802676e-5 8.118789909389183e-5 -6.821927755623877e-5 2.0032436703730754e-5 -6.056995053910267e-5 -2.2968571021522707e-5 -0.0001105571945222948 -0.00014637645737051089 -6.317273883865806e-5 4.3152619662979155e-5 -5.965598112339611e-5 -0.000119850760630982 3.652006042870051e-5 -3.804811504920774e-5 -9.999286058842988e-5 0.00012409962718124527; 0.00011855843393271653 0.00018987755236589133 -1.978960910498279e-5 -1.042954347218337e-5 -3.5438351296813716e-5 -5.149627697397303e-5 0.00010475386141474295 6.809195536329246e-5 2.6323904130491746e-5 -8.729391374861707e-6 7.167142091922874e-5 2.0682599533707283e-5 -7.157418378489673e-5 7.017346859350461e-5 8.780985870539185e-5 8.948939175694362e-6 -5.1330414244208e-5 -7.080079121857691e-6 -4.321166790764848e-5 0.00010138658461068492 -5.5851199540536734e-5 2.9668228926262466e-5 -5.850909230508834e-5 2.3086203036569273e-5 -5.514361630053558e-5 0.00015625495932903528 0.0001564004639294043 0.00012073166060867638 -0.0001409177342066471 0.00010840001839819306 -0.00013097833954821365 -6.617035189271217e-5; 8.349343381181958e-5 -2.0312531597442178e-5 -0.00014538554267433973 0.0001364514112711214 1.625048327609371e-5 -0.00014140814044454175 -0.00015707045411632358 6.337565111520702e-5 -0.00018151176362240026 3.7976450875285583e-5 -4.120811234403291e-5 5.7037426194297576e-5 -2.3837773079285726e-5 2.062497998762467e-5 -5.438314718135579e-5 -4.8405333182889036e-5 9.480660930334258e-5 2.3530008721033867e-6 4.4841902598860114e-5 1.5687617014050113e-5 0.00013309760226461035 -0.00012614120230951297 
-7.35667634304078e-6 0.00010592932205238928 0.00022496842695458747 6.0130039236395154e-5 -0.00014679240728669919 6.8647171387554e-7 4.086321252456844e-5 4.327219841711265e-5 0.0001122755941824341 4.089208352437801e-5; 9.054280638244979e-5 7.603029774051677e-5 -9.916943560803956e-5 -0.00014290921132763097 -5.655516425604037e-6 0.00012970133359792264 5.141539467302784e-6 -5.956949787199005e-5 8.345709323200352e-5 -5.323254988163544e-5 -5.482105149987334e-5 9.935784003479837e-5 -8.15411654648622e-5 -0.00010688953222690929 -6.875447031328039e-5 -5.336005012488959e-5 -8.832159398573405e-5 -6.097749025778302e-5 -6.951253778518671e-5 0.00021521519777613305 -3.908219306167305e-6 -0.00012724324512512932 0.0001188438004000761 -1.6631727313333492e-5 1.2035964963588441e-5 1.7917096946777278e-5 -0.00015724890142430819 6.932087152534897e-6 -0.00013000902578917737 7.930856047912e-5 -1.5114229036982203e-5 0.0001116504685605607; 1.415931665275911e-5 4.69298423023059e-5 9.490850361641298e-5 5.4509665734326616e-5 3.813709127014431e-5 -3.640586210069274e-5 -3.483338941707109e-5 -0.0002360182158544559 2.03054981294123e-5 8.559987627547231e-5 3.30703072004854e-5 1.3204301741171494e-5 7.894032323844533e-5 -7.89130032889543e-5 0.00021176872114763795 5.3841925638218935e-5 0.00011505592856435303 0.0001214641845773066 -8.834700993149382e-5 9.350812915825268e-5 -5.053430931995724e-5 3.2518742510375257e-6 1.777830339148969e-5 1.660366369928405e-5 -9.680525972628078e-5 -4.3832821301224495e-5 0.00012303523117323191 8.876741349111461e-5 -6.380161785175074e-5 4.706237023227088e-5 -6.294167969728623e-5 -1.0423210132376576e-6; 5.0541822247601466e-6 0.00014823913681994937 -0.00014998122185258745 2.4151861566694835e-5 9.578251649601333e-5 4.667390624036904e-5 -0.00016605171674654247 -1.652042681142927e-5 -0.00015009353353432706 -3.3606820011277335e-5 6.709436877375922e-5 -0.00013248602169823272 2.4001645783781152e-5 3.7010212167865694e-5 8.389649030960696e-5 -0.0001258088899476018 -8.08563867167117e-5 
9.458597799040248e-5 1.5925754837562567e-5 1.3919384400764213e-6 -0.00018009057509969933 -0.00018371144085448625 0.00020677960003469774 -4.129040587481e-5 0.00018245949343055395 4.4786006242262323e-5 -7.813616082329171e-5 -0.00012109599754238838 -6.543926554057914e-5 -0.00017203371805861957 -2.133690702520712e-5 -8.763928178721117e-5; 9.989165372013562e-5 0.00015028759915645452 6.3142137863055e-5 -0.00010726728689127268 -2.7204953964373424e-5 -2.15380234448635e-5 -0.00014301269876972595 -0.00013020536900234205 2.1699254633006022e-5 8.135660280737605e-5 -8.21380064889428e-5 2.5949699905195122e-5 -0.00015456056027409045 0.0002797061340816649 1.0070514536525618e-6 -0.00010057102664808333 -4.187737178375493e-5 -0.0001917043821021612 -0.00017699478232038028 -3.928295170171219e-6 8.4509071514786e-5 -2.791864718978127e-5 0.00013301684774256721 -3.234016659637668e-5 -0.00014433690850359146 0.0001327790839996499 2.3636611134007683e-5 -4.8394286947714245e-5 0.00015155847612100874 -0.0001393554968826176 5.1917670889995065e-6 0.00014203354911211425; 7.331512222782476e-5 5.731293766194111e-5 -1.4809537441291915e-5 -8.879714970925425e-5 0.0001649852541789243 -7.340136164083942e-5 -1.77681509910107e-5 -8.267006580235385e-5 -4.9550125021314046e-5 8.835175447649252e-5 -5.015155203971301e-5 7.412973844230147e-5 2.2514868197791215e-5 -1.5466776511572233e-5 5.3730019664093354e-5 -0.00012543415743374606 -6.801989604235858e-5 -5.345523333625747e-5 -7.823384168541075e-5 -5.721155195596503e-5 4.7521659672662674e-5 -0.00011595205671137239 -2.6073722091213435e-5 4.710372502934916e-5 7.04887764925897e-5 3.131958272325198e-5 6.994845364783354e-7 6.283214805198826e-5 -0.00015962066628953344 -7.362934649672385e-5 4.0157772110709416e-5 -0.0001422600276952718; -7.22324064237893e-5 -0.00016380281409985773 -0.00019248225250083546 -1.209825201404868e-5 -2.4471367182941977e-5 -0.00014749058279014406 -0.00014230735065687332 -0.00011567127764030891 -6.060214833742364e-5 -1.2504409610475076e-5 
-8.265116299767951e-5 3.690970082495455e-5 1.0714897357334286e-5 -5.0394605537068844e-5 -0.00010213921156285318 7.92046896411169e-5 -5.379815298983716e-5 -1.1844934092753442e-5 0.00010077930822696674 -0.00018950931077162982 -0.00014779385925541618 -0.0001352690694741156 -4.688139121317596e-5 -8.65590652803055e-5 -6.54000275655759e-5 9.952809817585683e-5 4.927688972637423e-5 -3.907889073637148e-5 -0.00017959787034235315 4.701896726580437e-5 7.887870491210383e-5 -0.000145866559946803; 6.134870543106211e-5 -6.56517813963845e-5 0.00018427996744362857 -0.00013511892851014436 6.45988538694527e-5 -6.0911025711478526e-5 0.0001118090285403641 9.809085801904781e-5 -5.444925503028632e-5 3.59222239880681e-5 2.9499828970476593e-5 0.0001353041434144589 -8.443687059256962e-5 0.00011229418211810759 -1.0829851843964777e-5 4.263752905031571e-5 1.4505873851210159e-5 2.048507185888117e-6 -0.0001536562786255046 7.820296532796969e-6 0.00022848507703244244 -1.0333846175427188e-5 -0.0001077281204102664 -0.00017024901265319645 5.977820089398505e-5 2.7102197209809144e-5 0.0002717132146885561 -2.027068606324179e-5 -7.850508793993227e-5 4.12513245196069e-5 -2.2600300353161058e-5 0.00013699427560864583; 2.2020021725332466e-5 0.00010195227508202383 -1.3528586413601195e-5 4.5073844313167416e-5 7.385534984759946e-5 3.902083272715681e-6 -0.00024466108886289834 0.00026689626968700735 0.00012953460093821108 6.041013687616037e-5 6.392106291137067e-5 -2.309887174223493e-5 5.849923064372946e-6 8.512948255054522e-5 -8.391524176576524e-5 -2.499750195885414e-5 3.916107737575944e-5 -3.1168947379318024e-5 3.762594854461521e-5 0.0001228440668736779 1.6704875005409617e-6 0.0002185637567262408 5.76440397079627e-5 0.00019930167757668018 8.495122158901787e-5 -8.401506062828496e-5 -8.472939231897148e-5 0.00017181594733383976 -0.00011680350878041072 7.23470787389139e-6 0.000171678315319611 3.854676827439173e-5; -1.3766648169278455e-5 -9.663506477973926e-6 -6.0033257557391876e-5 4.609440589991031e-5 
-9.54647336926621e-5 5.018029265773187e-5 -3.809657640669748e-5 -0.00010724819070457754 -4.583371331829993e-5 0.0001248445762597627 2.3733032560243697e-5 6.747894736907192e-5 -2.9454666889818138e-5 0.00018137325174609644 -7.348990105821998e-5 -7.36098452194915e-5 6.160713315910212e-5 -6.23962135943759e-6 2.8427760537287472e-5 5.641619578671349e-5 -2.7978526192067308e-6 -1.6936765621376836e-5 -8.003062315549188e-5 -0.00016528728950800404 9.476088291538689e-6 -8.487886350505258e-6 3.965366637419575e-5 7.937590931976889e-5 -6.944336995513465e-5 1.5236658566554351e-5 0.00011267445303235887 0.0001212893687464902; -0.00017419719316694033 -3.417500503291991e-5 -0.0001229361218183747 3.421789494499734e-6 3.0488391154905907e-5 9.652768547948805e-5 0.00019198808930650566 -3.2511477377737845e-5 4.72623401829023e-5 -5.730659851268804e-5 6.433436700446349e-5 -7.056911218842279e-5 0.00014465404145722478 -0.00019137321161269162 8.485298220602991e-5 -1.760173964884771e-5 -0.00011257174575165923 3.6659993008471724e-5 -5.992596871965342e-5 5.246324197924111e-5 -0.0001058140691502538 8.11914292962683e-5 -6.35964965231268e-5 -3.384026368892101e-5 9.282938170781725e-5 0.00017736389471520266 -5.7139728062790614e-5 -0.00016214666485484466 -0.00020733673537779278 -8.407104852540204e-5 2.8610486503566693e-5 4.375309853185888e-5; 2.632567337875957e-5 -5.300129852096137e-5 7.486394678288062e-5 -4.5021646666120725e-5 0.000101972411347539 -7.530255607982663e-5 -1.2174191407234059e-5 0.00013301087527813787 -3.385579476436886e-5 -6.605123418104987e-5 0.0002088027213076058 2.5489098324192544e-5 2.359725148363309e-6 -2.122922524997193e-5 -0.00014719060346671963 -7.921430734406505e-5 0.00021136633221338275 -7.942005687348115e-5 9.54047827590107e-5 -6.342493362446816e-5 -0.00017324086003193173 5.901582416277404e-6 9.26290614623855e-7 -0.00011337386605283509 6.745701274645735e-6 -9.638826669381409e-5 -6.13231568522887e-5 -0.0001641192102497593 0.00015921580415031354 -4.023960355973443e-5 
6.744431810964247e-5 -1.75730656437921e-5; 0.00013766845454438363 -3.179847540240407e-5 -8.180965001020164e-5 4.116502764145889e-5 1.6581556560657214e-6 3.2158093740755703e-6 -0.0001678491802981546 0.0003221976972322481 0.00015908404252468493 3.2773894917681816e-5 5.326863770232387e-5 1.3911063023905403e-5 -1.0605713624046675e-5 -6.643579167803175e-5 -1.6662502685440847e-5 5.490151174808723e-5 -2.6089602625232597e-5 -0.00014859194693644232 -1.3887115938351922e-5 -0.00011503085756930227 0.00017579620412062156 8.59015149667438e-5 1.6325410387710275e-5 9.890612504885456e-5 -3.973065527179004e-5 2.1049123960897534e-5 -1.991038102505112e-5 -4.546420445921416e-5 7.140624306840212e-5 4.630718699400743e-5 -1.6570689195296863e-5 -0.00012274164823838807; -0.00011224393710180379 -2.6214098389377588e-5 -0.00017563036998653357 -3.576915513875798e-5 2.928575487132036e-5 -7.282336627284643e-5 2.784133540373332e-5 0.00015262528723652906 -3.569201179815534e-5 -0.00014342469693546533 -2.3577696984649494e-5 6.418015073833563e-5 -9.839797704257431e-6 -0.00015152139712044374 -8.995975795899833e-6 -4.27948852946339e-5 3.592512622090794e-5 0.00011744135373120561 0.00013756226508959693 -1.0367306450266512e-5 6.282007780938702e-5 1.7843203758472194e-5 -7.846821426100192e-5 3.5095445136165546e-5 0.00017766697966109366 0.00010500990721012635 7.403510790865143e-5 1.581776269560238e-5 0.00021932592609047346 6.444255814969113e-5 0.00011497111518948868 -9.368926293135889e-5; 9.390543322346191e-5 2.8816089743971013e-5 2.5784449831392852e-6 2.910282068163363e-5 8.431563377647188e-5 -0.00015002155510907564 4.8035529962846106e-5 -0.00010361845148460591 -3.982412477804439e-5 4.722129573614505e-5 2.469816876377872e-5 -6.951451747178755e-5 5.253805988153758e-5 -6.128964401746948e-5 4.7355475308478594e-5 -1.6325675696372847e-5 -0.00013462811717496233 1.342622430505179e-5 -4.39967736924411e-5 -8.940353678505451e-5 0.00011587625680525337 -6.0423521299042894e-5 -0.0001423910982880097 -2.7974933127044668e-5 
6.195802756386314e-5 -2.2909138587660652e-5 -1.9991723328953546e-5 0.00013034587367704434 -0.00011006535010844078 -9.604708120121628e-6 -0.00010462386513560754 -1.5560165224634359e-6; 8.062401708634026e-5 -4.236098878715065e-5 1.7561226072251445e-5 8.393326085246425e-5 0.0001570955685242002 8.183468550163463e-5 -0.00013649143123581 -0.00012413563284604704 4.602031573489642e-5 -7.765013627755e-5 -7.094675287185376e-5 1.0657608313011725e-5 -2.9591135658444258e-5 -1.7482041344517077e-5 6.027067889994073e-5 -7.592398081720466e-5 3.429427694249353e-5 1.962153534299364e-5 -6.0527825312930314e-5 -8.948635853408699e-5 -5.869855861925352e-5 -6.406259829924375e-5 -2.0929913931066486e-5 1.096696475037878e-5 4.878790811238957e-5 -8.470796068246251e-5 5.044555679382031e-5 0.0001277088633873415 -9.476725584641905e-5 0.00012297922900364808 -4.751215213724349e-5 -5.97157120278622e-5; -1.0178621580854946e-5 7.333809500516446e-5 -2.563884643939881e-5 0.00018021192974367184 -0.0001538872091103993 -0.00016907623629129447 -1.0999015815178161e-5 -2.7758227391023533e-5 5.8770452794413774e-5 -6.497828852196294e-5 -0.00010794020616493377 -5.634225434085284e-5 -0.00014400715716227048 0.00019699388232503957 0.00019142407856372257 -2.081155233361784e-5 -0.0001264432428641901 -0.0001031676581832296 2.650968437330401e-5 3.5410296020126436e-5 3.4734468697132907e-5 0.00010580013642276712 -0.0001937787817795484 0.00010511194724778308 -5.499162196691105e-5 0.00029245188508604666 -8.931649043268502e-5 0.0001079216819717835 8.400292446560278e-5 -2.349275362837989e-5 -0.0001768822619695861 0.00014051629861307278; -3.2003747346504775e-5 0.00015890783803123982 -7.968852823866906e-6 6.9865147089360565e-6 -2.5819605380005665e-5 -0.0001629461843002947 -1.8210399812151068e-5 -0.00018865759951941773 -9.402885861974102e-6 -3.762525406262486e-5 -9.540364439076953e-5 -0.00011138898875279679 -9.28091033891881e-5 -0.00020221023472902992 2.695648950983849e-5 0.0001844536233512457 -7.67507811723474e-5 
-0.00017211284315626977 -0.0001017176476488292 6.39606511504959e-5 -0.0002505556697905406 2.8997033641725836e-5 -9.108881962664832e-5 0.00013965841904807548 3.374904159508819e-5 5.32913287862365e-5 -1.0918933775432076e-6 -0.00017918210536655316 9.571123488066747e-5 -3.96055914642694e-5 -0.00016766582151823132 -2.3486580484936776e-5; -3.2757407072248886e-5 7.431461383773065e-5 -0.00018415628616978952 9.620045628498102e-6 -7.771467969335536e-5 -0.00021591083529686484 -7.592851213460776e-5 -8.623169739779049e-5 -3.193258268923266e-5 2.972970108046068e-5 8.13002955353514e-6 -2.2232094099065213e-5 -9.666932052980713e-5 1.9112578255109075e-5 8.840297053053974e-5 6.422093977404995e-5 8.579726820996309e-5 0.0001126930128778151 0.00014150559402721017 8.035859903272305e-5 1.7949056226160338e-5 0.00013993647556005052 -9.272902568384767e-5 -3.499442206859542e-5 1.033185983785035e-5 -0.00011463231382706996 0.00012920104748288295 8.827484091695413e-5 -0.0001273585312189715 8.872250876108194e-5 -2.1777143021365626e-5 -0.00014436400121824508; 3.87722613486137e-5 0.00017321202193486254 2.526933366240847e-5 9.086049266258834e-5 4.3574016260447126e-7 -7.128862472384504e-5 3.078256867226928e-5 -8.816573227889144e-5 0.00015166690735711306 0.00016703866293752993 -5.544763899891459e-5 -0.00011443299427994744 -6.227235263580078e-5 7.082161772629682e-5 0.0001665965175452342 3.2478652599794364e-5 7.182027928863202e-5 -0.00011396000610332267 -1.5627137481958232e-6 0.00011987723718263877 -5.192078225448083e-5 -1.1706093402632755e-5 2.6589528886662366e-5 0.00010627506725355858 -0.00011658547720073992 6.409522046667195e-5 -0.00019066976511751614 -7.996417099277712e-5 -0.0001549148944941654 -6.809776806405963e-5 3.483841300110125e-5 -4.2622526791119095e-5; 9.199835112708781e-8 -0.0001699257623153674 4.652146205948998e-5 7.686279309884028e-6 0.00011545409549903748 7.0163541527586645e-6 7.417188360300602e-5 5.2867690533778725e-5 7.070050239740422e-5 0.00019550947079146667 -2.6488672596727237e-5 
6.037154205437508e-5 -4.768258761437278e-5 1.4107274615116609e-5 -0.0003214404782961053 1.5078744104858607e-5 4.610047151395561e-5 -0.00013025328328189733 0.00013802600049125718 7.115091327755419e-5 5.917258909496336e-6 0.00011504065375952744 -7.039330332595378e-5 3.920895003424753e-5 -1.879546064123813e-5 -2.141208963116527e-5 1.5772306574537662e-5 2.7854179712042524e-5 -4.520446193459728e-5 -0.0001852661294146372 7.92537034390564e-5 -3.928709724686849e-6], bias = [9.833212068595948e-10, 5.031240309549707e-11, -1.6435266012083702e-10, 3.260358872786313e-10, -1.3013027744622694e-10, -2.1764354454452807e-10, 2.468833374718213e-9, 8.789322527949966e-10, -6.353823975796135e-10, -1.5360561184798549e-9, 2.1397499458253354e-9, 7.942790450239258e-10, -1.1528491373497192e-9, 1.931131083652193e-9, -2.5041799659905464e-9, -5.940973462996402e-10, -1.5521606621617342e-9, -5.9760761632392535e-9, 2.2400585925256843e-9, 4.511030880991829e-9, 4.2133467957509644e-10, -1.3527942429435387e-9, -7.312437670878791e-10, 1.4138545420495813e-9, 1.5115885768711373e-9, -1.4372376140579134e-9, -6.655096163288208e-10, 2.4030407033726136e-10, -4.460471164593126e-9, -5.765939317654227e-10, 5.022120361819514e-10, 5.31577419006508e-10]), layer_4 = (weight = [-0.0006693916946785522 -0.00087811056246193 -0.000658535669691793 -0.000823853132199716 -0.0006709941821680758 -0.000718664462000433 -0.0005830925627937294 -0.0007449466107940803 -0.0006163008050693001 -0.0005307111792350577 -0.0007943188357510313 -0.0007345662026539919 -0.0007325240681090986 -0.0007572817106682904 -0.0005952295872327212 -0.0007137988602273496 -0.0006330283549593247 -0.0006465232184123022 -0.0008911875296839423 -0.0005730279730456309 -0.0005656691506992044 -0.0007595886506220198 -0.0008123492919531509 -0.0006281099806106889 -0.000488916342543989 -0.0007064036913165266 -0.0007752412997183317 -0.0008336575884229319 -0.0005707755612950293 -0.0007412388690887299 -0.0006868003676349305 -0.000706514726664443; 0.0001869810561571447 
0.00026138472042665857 0.0003357181368993517 0.00025652838510506704 9.97421041455466e-5 0.00039855819608919016 0.00023919619142872883 0.00035240975832423195 0.00035641583168601526 0.00014358109577566693 0.0003882508683086808 0.00041956713951922596 0.00036676820034798534 0.00014138174049343322 0.00016794329012317466 0.0002400038044527481 0.0002828345627498267 0.0003706013258097437 0.00018089068422404847 0.0001203625082880905 0.00018263306418642709 0.0003410737766694286 0.00020391439791354353 0.00011304708388295691 0.0001623782312510706 0.000203860128433107 0.00030454442330103 0.00023277676229339068 0.00021300622144586494 4.389363309753426e-5 8.112596047742464e-5 8.669705905241166e-5], bias = [-0.0006498800266729953, 0.0002172765629992668]))Visualizing the Results
Let us now plot the loss over time
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # Recorded loss history: a thick translucent line with a marker on every entry.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally let us visualize the results
# Re-run the fixed-step RK4 simulation, this time with the optimized parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn =
    solve(prob_nn, RK4(); u0=u0, p=res.u, dt=dt, adaptive=false, saveat=tsteps) |> Array
# Only the first value returned by `compute_waveform` is kept.
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Draw each series as a translucent line plus a marker at every saved time step.
    # The series are processed in the same order as the legend labels below.
    (l1, s1), (l2, s2), (l3, s3) = map((waveform, waveform_nn, waveform_nn_trained)) do ys
        curve = lines!(ax, tsteps, ys; alpha=0.75, linewidth=2)
        dots = scatter!(
            ax, tsteps, ys; alpha=0.5, marker=:circle, markersize=12, strokewidth=2
        )
        return curve, dots
    end

    # Each legend entry groups the line and the markers of one series.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix
using InteractiveUtils

# Print the Julia build and platform information for reproducibility.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are
# both defined in this session and MLDataDevices reports the device type as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.4
Commit 01a2eadb047 (2026-01-06 16:56 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.