Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.
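We need a very crude two-body path. Assume the one-body motion is the Newtonian relative position vector of the two bodies; `one2two` splits it into the two individual positions, keeping the centre of mass at the origin.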
"""
    one2two(path, m₁, m₂)

Split a relative-position `path` into the positions of two bodies with masses `m₁` and
`m₂`, keeping their centre of mass at the origin (so `r₁ - r₂` equals `path` up to
rounding).

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    # Each body sits on the line of `path`, weighted by the *other* body's mass fraction.
    w₁ = m₂ / total
    w₂ = -m₁ / total
    return w₁ .* path, w₂ .* path
end
one2two (generic function with 1 method)
Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an orbital solution to Cartesian coordinates. Each column of `soln` is one time
sample.

- With 2 rows, `soln` holds `(χ, ϕ)`. The constant orbit parameters are taken from
  `model_params = (p, M, e)`, which is then required.
- With 4 rows, `soln` holds `(χ, ϕ, p, e)` and `model_params` is ignored.

The radius is `r = p / (1 + e cos χ)`. Returns a `2 × size(soln, 2)` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.

Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if it has 2 rows and
`model_params` is missing or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]
        # Check for `nothing` first, so a missing argument reports the contract instead of
        # failing with a MethodError from `length(nothing)`.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
This function uses second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the samples `v`, spaced `dt` apart.

Interior points use the second-order central difference. The first and last points use
second-order one-sided stencils, so the result has the same length as `v`. Requires at
least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left, interior, right) / dt
end
d_dt (generic function with 1 method)
This function uses second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the samples `v`, spaced `dt` apart.

Interior points use the three-point central stencil. The endpoints use second-order
one-sided four-point stencils, so the output has the same length as `v`. Requires at
least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    body = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, body, tail) / (dt^2)
end
d2_dt2 (generic function with 1 method)
Now we define a function to compute components of the trace-free (quadrupole) moment tensor from the orbit.
"""
    orbit2tensor(orbit, component, mass=1)

Return the `component` entry of the trace-free moment tensor of a planar `orbit` (rows
`x`, `y`), scaled by `mass`.

- `(1, 1)` gives `x² - (x² + y²)/3`.
- `(2, 2)` gives `y² - (x² + y²)/3`.
- Any other component gives the off-diagonal `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    x, y = orbit[1, :], orbit[2, :]
    Ixx, Iyy = x .^ 2, y .^ 2
    trace = Ixx .+ Iyy
    tensor = if component[1] == 1 && component[2] == 1
        Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        Iyy .- trace ./ 3
    else
        x .* y
    end
    return mass .* tensor
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the finite-difference second time derivative (sample spacing `dt`) of the
`component` entry of the trace-free moment tensor of `orbit`, scaled by `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate `h_22_quadrupole_components` for the `(1, 1)`, `(1, 2)` and `(2, 2)` components
of `orbit`. Returns the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    h11 = component((1, 1))
    h22 = component((2, 2))
    h12 = component((1, 2))
    return h11, h12, h22
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` of a single body following `orbit`, sampled
every `dt`.

The constants `π` and `2` are converted to the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    Q11, Q12, Q22 = h_22_quadrupole(dt, orbit)
    norm_factor = √(T(π) / 5)
    plus = norm_factor * (Q11 - Q22)
    cross = -norm_factor * (T(2) * Q12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the `(h11, h12, h22)` quadrupole terms of two bodies, each given by its own orbit and
mass. Returns the combined tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    # The two contributions add component by component.
    return (first_body[1] + second_body[1],
        first_body[2] + second_body[2],
        first_body[3] + second_body[3])
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` from the orbits of black hole 1 (`mass1`) and
black hole 2 (`mass2`), sampled every `dt`.

Throws an `AssertionError` unless `mass1 + mass2` is 1 to within `1e-12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"
    Q11, Q12, Q22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    norm_factor = √(T(π) / 5)
    return norm_factor * (Q11 - Q22), -norm_factor * (T(2) * Q12)
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)`, sampled every `dt`, from an orbital solution.

`soln` and `model_params` are forwarded to `soln2orbit`.

- `mass_ratio == 0` selects the one-body (test-particle) strain.
- A positive ratio `q` splits the orbit between two bodies of masses `q / (1 + q)` and
  `1 / (1 + q)`.

Throws an `AssertionError` unless `0 ≤ mass_ratio ≤ 1`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio≤1 "mass_ratio must be <= 1"
    @assert mass_ratio≥0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # Test-particle limit: there is no second body to track.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
compute_waveform (generic function with 2 methods)
Simulating the True Model
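RelativisticOrbitModel
Defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using $u[1] = \chi$ and $u[2] = \phi$:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p-2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}},$$

where $p$ (semi-latus rectum), $M$ (mass), and $e$ (eccentricity) are constants.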
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the Schwarzschild orbit ODEs for the state `u = (χ, ϕ)` with constant
parameters `p`, `M`, `e`. `t` is unused.

Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    c = cos(χ)
    # Shared factor of both rates, and the normalisation that depends only on (p, e).
    common = (p - 2 - 2 * e * c) * (1 + e * c)^2
    norm = sqrt((p - 2)^2 - 4 * e^2)
    rate_χ = common * sqrt(p - 6 - 2 * e * c) / (M * (p^2) * norm)
    rate_ϕ = common / (M * (p^(3 / 2)) * norm)
    return [rate_χ, rate_ϕ]
end
# Global constants of the experiment. Most of them are read inside `loss`, so they are
# declared `const` to keep those global accesses type-stable.
const mass_ratio = 0.0 # test particle
const u0 = Float64[π, 0.0] # initial conditions
const datasize = 250
# Time is kept in Float64 as well. `dt_data` fixes the precision `T` used by the strain
# functions (`T(π)`, `T(2)`) and the `dt^2` divisor of the finite-difference stencils.
const tspan = (0.0, 6.0e4) # time span for the GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
const dt_data = tsteps[2] - tsteps[1]
const dt = 100.0 # fixed RK4 step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl.
# Solve the relativistic model on a fixed RK4 grid and keep the plus polarisation as data.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `waveform` is the training target read inside `loss`; `const` keeps that access
# type-stable.
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    # Plot the target waveform (first entry returned by `compute_waveform`) over time.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    # Combine the line and its markers into a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the phase variable χ = `u[1]`) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl.
# The input first goes through `cos` elementwise, so the network output is 2π-periodic in
# its input. Weights start tiny (std 1e-4) and biases at zero, so the untrained network's
# output is close to zero.
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
# Use an explicitly seeded RNG so the initial parameters (and hence training) are
# reproducible from run to run.
ps, st = Lux.setup(Random.Xoshiro(0), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[3.3762408f-5; 9.377616f-5; 6.3974345f-5; -0.000110069304; -0.0001446427; -5.673938f-5; -5.6276687f-5; -0.00012045649; 7.0082664f-5; 6.0663122f-5; -0.00011987695; 0.0002174891; 0.00014940342; 3.1554844f-6; -4.8612637f-5; 7.799579f-5; 5.9642484f-6; -7.402087f-5; -2.3927209f-5; 0.00020453366; 5.863988f-5; 0.00018471539; -5.4077347f-5; 2.1376543f-5; 1.4348437f-5; -1.7542276f-5; 1.9830833f-5; -0.000101488455; 9.802708f-5; -4.774987f-5; -4.392182f-5; -0.00015505681;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-3.167835f-5 -0.00017326728 2.1456239f-5 3.0559465f-6 8.670264f-5 2.1129179f-5 8.137731f-5 2.2439874f-5 2.1673586f-5 -9.720361f-5 -6.561089f-5 -7.939323f-5 -1.1814809f-5 -0.00012521364 -6.9161515f-6 2.2955115f-5 9.199664f-5 -9.001195f-5 -5.724703f-5 0.00019338753 -8.155465f-5 -0.0001060708 3.4722056f-5 0.0001655086 -0.00014083457 -0.00010901833 5.5821452f-5 5.1261355f-5 7.777724f-5 -1.457528f-5 -3.0184789f-5 3.0344683f-5; -8.7531276f-5 -0.00015944398 8.823492f-5 0.00014740053 7.6892495f-5 2.8300368f-5 0.00019259054 1.3385513f-5 -0.00019908906 0.00017361736 -5.7279784f-5 0.000121739824 0.000101182515 0.00011984331 6.5729164f-6 3.7271355f-5 5.529983f-6 8.933739f-5 6.0020848f-5 0.00011047311 -0.00019905614 -0.00018193282 -8.5353684f-5 -6.6237946f-5 1.1670401f-5 -0.00018943203 -1.05960535f-5 3.9650065f-5 -6.8248046f-5 -7.978318f-5 2.643545f-5 -1.6347192f-5; 7.9900165f-5 -3.3791195f-5 -4.461751f-5 8.198608f-5 -8.8874986f-5 5.1846946f-5 4.9173075f-5 2.1226306f-5 6.1615894f-5 5.2050676f-5 8.6817075f-5 5.0423812f-5 -0.0001565413 4.3377513f-5 -4.9897844f-5 0.00010143122 -3.865422f-5 0.00010972347 -9.389309f-5 -3.9290477f-5 -0.00011096744 -9.6759264f-5 -0.0001328251 -5.0886763f-5 -4.8319594f-5 6.971236f-5 9.5485724f-5 -4.2470463f-5 -2.2584534f-5 -0.00014199188 
0.00010841379 3.4423352f-5; -4.0085677f-5 5.846745f-5 -3.457356f-5 -3.6869521f-6 7.9050403f-7 -8.290168f-5 0.00012858551 7.235046f-5 0.0003001503 2.2835804f-5 4.4499127f-5 0.00011936336 9.2190356f-5 -0.00010763649 9.156982f-5 -0.00015461422 -0.00018267597 0.00011015288 -7.390641f-5 -2.6609117f-5 5.914308f-5 6.25383f-5 0.00013434085 -0.00011991429 -9.973299f-6 -1.4919258f-6 5.9757043f-5 -5.8327118f-5 4.3632324f-5 2.6743108f-5 3.547425f-5 2.4756053f-5; -0.00010693439 -3.2330692f-5 0.00013549883 0.00016848063 -4.4991302f-5 1.6312048f-6 5.308744f-5 -0.00014600328 5.1391162f-5 -1.7519033f-5 6.6223634f-5 -1.0482116f-6 0.00012321475 -6.4843494f-5 -8.60347f-5 7.831225f-5 -5.1964314f-5 -9.3510746f-5 1.973614f-5 -5.969374f-5 0.00019920986 -0.00015265046 -3.649828f-5 1.37049465f-5 0.000109683104 0.00010686412 7.906872f-5 3.469373f-5 -7.4391144f-5 2.1635075f-5 6.394174f-5 -6.66925f-5; 5.961434f-6 5.683609f-5 -1.8216513f-5 7.305108f-5 -8.100786f-5 3.500533f-5 0.00013942244 2.8661281f-5 -2.4896002f-5 -0.00018749913 5.7354482f-5 1.9849773f-5 8.694509f-5 3.3665707f-5 4.3606826f-5 -0.00011435062 3.3737735f-5 -0.00010216421 -4.5987123f-5 -2.4170382f-5 0.0001958456 -3.7196503f-5 8.751697f-5 5.6783803f-5 -0.00017161861 8.976308f-5 1.9316223f-5 0.00021105785 4.522042f-5 -2.7782213f-5 9.955028f-5 0.00039455056; 0.00015191818 -6.089402f-5 0.00013581767 -4.411443f-5 2.239457f-5 -6.502179f-5 5.0726035f-6 -6.363439f-5 0.00014961598 0.00013880347 -5.055871f-5 -0.00023794136 -6.273292f-5 -8.237891f-5 -2.5595862f-5 5.224788f-5 2.8504228f-6 3.136851f-5 4.044638f-5 -0.00011080653 -0.00016883439 5.347921f-5 0.00010104189 -0.0001889074 0.000113251175 -0.00012653953 -4.8131635f-5 -0.000102627964 6.6101784f-5 -8.509704f-5 -2.3070636f-5 3.3479577f-5; -8.750504f-5 2.9171926f-5 0.00022042167 -0.0001931004 0.00015282453 7.9011756f-5 0.00018839244 -0.00016074357 -0.00012011719 -0.0002386075 -8.297239f-5 -1.8730367f-5 -7.690997f-5 6.211139f-5 0.000111024245 -8.48213f-5 -0.00024238117 0.0002331874 
0.00020022548 2.7068552f-5 -7.6184864f-5 -1.4874527f-5 -8.710071f-7 -0.00010648643 -3.342965f-5 0.00026645238 0.00011235834 -0.0001327587 -9.5918855f-5 5.35745f-7 -8.6018495f-5 -2.139972f-5; -5.2919328f-5 3.726837f-5 -5.533571f-5 3.1335163f-5 -6.62358f-5 -2.6162641f-5 1.0649018f-5 0.00016081057 5.0430885f-6 -4.234112f-5 -0.0001000101 -9.433519f-5 -5.4847966f-5 0.000294391 -1.1442187f-5 3.6263587f-5 0.0002745655 1.8441973f-5 -7.722601f-5 -1.9540525f-5 2.3949342f-6 -9.0249894f-5 -5.565176f-5 0.00021469402 -0.00012355864 2.9499757f-5 -0.00017732766 4.4193155f-5 -6.565675f-5 7.0172384f-5 3.3551405f-5 -3.5457144f-6; -0.00010798544 -4.9830025f-5 6.0300645f-5 7.685219f-5 -9.018578f-5 -0.00010038452 -8.031163f-5 -7.609693f-5 0.00011582041 0.00017185153 -5.787316f-5 2.3187496f-5 3.9537917f-5 -3.254601f-5 4.0907857f-5 -0.0001247215 -1.2792882f-5 -4.136524f-5 5.163726f-6 4.176694f-5 -2.517533f-5 9.350496f-5 -0.00016413647 -5.4292792f-5 -4.6962738f-5 9.311813f-5 -0.0002546033 -7.556855f-5 0.0001528439 1.272239f-5 -0.00010227689 -0.00012112792; 2.5478654f-5 -1.8925972f-5 1.3522798f-5 7.9053636f-5 -1.9931056f-5 -0.00016947588 -2.5421677f-5 0.00023060103 1.7039229f-5 -0.00012146738 6.9841312f-6 -8.291776f-6 0.0002417237 -4.4259476f-5 -9.1143105f-5 0.00013496332 0.00019254164 0.00012783299 -4.6915415f-5 -3.905643f-5 -0.00011279506 -8.80961f-5 -5.5542707f-5 7.8158926f-5 -2.3275632f-5 1.8245255f-5 -8.382284f-5 -0.00013564869 -8.543443f-5 -0.00019596776 -1.9594525f-5 0.0001296699; -0.00011558923 1.3172018f-5 2.3640616f-5 -1.3789628f-6 -9.748229f-5 7.7190365f-5 -1.2993842f-5 -0.0001229382 -0.00015768768 -7.704191f-5 0.00021275727 -7.256848f-5 -2.8602703f-5 -7.2770745f-5 -4.810583f-5 -8.694891f-5 0.0001419347 5.4706696f-5 -2.9066492f-5 -6.2310115f-5 -0.00015988392 0.00013972251 0.00023841354 4.878433f-6 2.1148624f-5 2.7470458f-5 -0.00010487495 -2.4921412f-5 2.357649f-5 -5.891703f-6 0.00014575128 4.2540898f-5; 2.09383f-5 -0.00017817278 7.1572917f-6 -0.000108034874 -6.4670116f-5 
3.5526818f-6 -0.00012206268 -2.5206886f-5 4.9970757f-5 2.7443182f-5 -7.563465f-5 -2.0181504f-5 -5.954069f-5 5.604793f-5 5.825142f-5 -2.4075583f-5 -0.00017586861 0.00012281955 4.6896792f-5 -0.00012155193 -0.00015880697 6.9155685f-5 7.9987214f-5 5.610593f-6 -0.000116922005 9.0508256f-5 -0.00013882587 2.1072745f-5 -8.1383005f-5 -0.00010210284 4.1260406f-5 -8.6248565f-6; -0.00029144282 5.9778853f-5 -2.5181285f-5 -8.5456495f-6 -5.384766f-5 -0.00018108908 7.8775985f-5 -7.4875235f-5 -2.700818f-5 -2.4299436f-5 -3.128644f-5 -2.0955818f-5 2.1625636f-5 5.70831f-5 4.623904f-5 -4.7818983f-5 -0.00014378567 -0.00013122981 7.108028f-5 -3.5296376f-5 -7.0865215f-5 8.400256f-5 9.037617f-6 -9.841278f-5 -1.2612928f-5 -9.9217f-5 -0.00013333991 -5.685295f-5 -2.0524003f-5 -9.537318f-5 0.0001021231 -0.00024132116; 0.00013585645 -3.3810116f-5 0.00014966111 2.482841f-6 -0.00011383134 0.00017412056 -0.00012280005 0.00010608431 0.000119174605 -2.2118495f-5 -2.0155996f-5 6.255633f-5 0.00015060208 5.0649556f-5 -0.0001349138 -5.75867f-5 -6.1194434f-5 7.8931145f-5 4.810006f-5 7.814027f-5 0.00017248065 5.1117942f-5 5.6450357f-5 7.465614f-5 -0.00014229953 0.00016677874 -5.6629557f-5 -0.0001210063 1.3957021f-5 -0.00019529514 -1.3669461f-5 2.1667754f-6; 0.000106820495 -0.00010066837 -6.977883f-5 0.00011087112 4.0440464f-6 6.853523f-5 -0.00015421756 5.421576f-5 3.6845377f-5 -0.00011023941 2.6635293f-5 -0.00012771956 6.3490916f-5 5.984918f-6 -0.000117452226 -3.3283653f-5 0.00023612365 -0.00016214218 0.00016155963 -0.00017134123 -7.523442f-5 -8.746304f-5 -7.328123f-6 1.6101205f-6 -2.3032846f-5 0.00017523281 3.665834f-5 7.939483f-5 -5.3760665f-5 0.00015392218 1.4726962f-5 -4.6564703f-5; -6.0400933f-5 5.3426418f-5 7.659585f-6 2.173049f-5 -0.00017288668 6.746041f-5 -4.6870504f-5 -0.00012182365 -6.528104f-5 -4.8989153f-7 0.0001951214 -3.8070462f-5 -0.00015769474 -9.363801f-5 -7.6088254f-5 5.6629502f-5 4.9076214f-5 -0.00023589538 -2.1136435f-5 -1.668819f-5 0.00016326466 7.959219f-5 -4.673233f-5 -0.00017696091 
-2.5152138f-5 -4.7545276f-5 -0.000116819465 -1.686003f-5 -0.000117502714 2.1703123f-5 -9.827086f-5 -9.332105f-5; -3.0508694f-5 -8.346137f-5 0.000112655405 8.562069f-5 -5.7245798f-5 0.00012665142 -6.502684f-5 0.00011166917 7.77302f-5 2.7728067f-5 -0.00013862271 -0.00017658203 1.9176246f-5 3.0980787f-5 4.5688772f-5 -6.4600235f-6 -0.00010869345 -2.2893946f-5 -8.853777f-5 7.884912f-5 -9.2184055f-5 -1.799283f-5 -0.00012120603 0.00014078588 0.00031047917 -1.2479958f-5 -7.15276f-5 7.245281f-5 -0.000113480346 -6.782734f-5 7.361784f-5 4.490223f-6; -0.00017959323 3.594937f-5 -6.952185f-6 -3.9215174f-5 6.7309884f-5 -9.705446f-5 0.00012202641 -4.824043f-5 9.203587f-5 4.7241658f-5 -6.5930703f-6 -0.00019555441 0.00011294411 4.43076f-5 -2.1831385f-5 6.416468f-5 -4.0306517f-5 4.2001026f-5 4.4623488f-5 9.325437f-5 -0.00011736338 -9.3871f-5 -7.8429926f-5 5.0306848f-5 -3.2532524f-5 -7.5049684f-6 0.00011726638 0.00012003368 2.5677859f-5 0.00011481601 -4.3209595f-5 -9.573343f-6; -4.465724f-5 0.00012467794 -0.0001745078 0.000137037 0.00022906155 5.972709f-5 -4.689057f-5 0.00013774176 -9.282728f-6 1.7477296f-5 -7.081772f-5 5.087943f-5 0.00010610425 2.4112329f-5 -2.8728364f-5 -4.077014f-5 8.04921f-5 -5.017782f-5 -9.510996f-5 -2.9785173f-5 3.242094f-5 6.490675f-5 -4.4944205f-5 2.1774717f-5 -5.2659118f-5 5.691948f-5 7.062748f-5 8.985865f-5 -0.00012179623 -2.5220117f-5 2.5510724f-5 -0.0002175639; -1.8064444f-5 0.00011100678 -0.00012345429 -6.8100087f-7 -6.29078f-5 2.4937353f-5 0.000110023386 6.0152997f-5 1.976271f-6 9.334225f-5 -4.274406f-5 -5.786669f-5 0.00018662655 -4.9786806f-5 -3.6834215f-5 -3.8196493f-5 0.00015880959 5.664257f-5 -3.497574f-5 4.00917f-5 9.9629295f-5 -7.87449f-5 9.905165f-5 6.700846f-5 -0.00016735493 0.00020870792 0.00011272673 -3.4326695f-5 3.5691122f-5 -5.517219f-6 0.00010668557 5.9691884f-5; -0.00012167978 -9.575844f-5 -0.00012283474 1.8576848f-5 8.6771666f-5 -0.00010718201 3.0729672f-5 -1.7757171f-6 2.4523291f-5 4.6492583f-5 -2.06563f-5 -2.0301126f-5 -6.641399f-5 
-5.249451f-5 -3.967492f-5 -0.00011529061 -3.3575085f-5 -7.832703f-5 -7.086144f-5 4.7096055f-5 -7.746475f-5 -1.2006147f-5 0.00012437307 0.00020168748 -8.130576f-5 5.6309244f-5 9.310638f-6 -4.1956857f-5 5.5005083f-5 -5.25454f-6 3.8032635f-5 -0.00013170215; -2.6851441f-5 0.00017783718 4.5854224f-5 -3.6386835f-5 -0.00013216895 -3.197749f-5 0.00014737278 -8.5781445f-5 1.444617f-5 -7.3324256f-5 -9.9262885f-5 -2.3837525f-5 -7.674186f-5 2.9352233f-5 7.101991f-5 -0.0003128677 4.4337976f-5 3.8268918f-5 2.1603017f-5 -2.464503f-6 -3.051723f-5 -0.000124008 0.00016717735 -4.8534508f-5 -2.9526771f-5 -0.00019507401 1.5154287f-5 4.0400348f-5 -3.7619866f-5 -1.4056087f-5 -8.219701f-5 8.658503f-5; -0.00011823278 -0.00010530102 -2.9943392f-5 -0.00011595517 9.647378f-5 0.00018556356 7.303799f-5 -7.895071f-5 -7.2233735f-5 5.7866342f-5 7.4216814f-6 -7.697309f-5 0.00012779947 -5.6239674f-5 0.00012688992 -0.00014173011 4.6928566f-5 -7.251575f-5 2.717461f-5 -8.0355756f-5 0.00010897228 -7.866558f-5 -4.510712f-6 9.637687f-5 -2.452513f-7 -1.0320719f-5 -4.534992f-5 7.328235f-5 0.0001222505 2.1149528f-5 2.6855143f-5 -8.261682f-5; -5.7380912f-5 9.119424f-6 8.353689f-5 -0.00011725962 -3.829574f-5 -1.724173f-5 0.00011100296 -1.0220104f-5 0.00010768657 -0.00014718034 -4.1799776f-5 -0.00016172924 2.75562f-5 5.1200346f-5 0.00013696666 -0.00016442579 2.2160524f-5 -6.747441f-5 -7.15149f-5 -3.682068f-5 -0.0001041959 -5.924469f-5 1.0310588f-6 -5.5798748f-5 -7.1036164f-5 3.6621637f-5 -0.00017607115 -7.1527866f-5 -2.904257f-5 -0.00016707851 -0.000105729894 5.3624874f-5; 0.000114194445 -8.918752f-5 -3.297131f-5 -2.377171f-5 -0.00011117759 -0.00013393247 3.2435124f-5 -0.00010416534 -9.933769f-5 -0.000123389 0.00026072905 -6.225539f-5 0.0001110872 -0.00021235549 1.5761654f-5 -0.00011595071 4.9918202f-5 2.6204661f-5 -0.00014675439 -4.2643223f-5 -3.7517715f-5 -3.0077406f-6 3.134166f-5 0.00013063697 5.1043124f-5 0.000112253074 -9.566826f-5 0.00016561401 2.605319f-5 -1.9149367f-5 1.3506367f-7 2.0405687f-5; 
-4.657069f-5 -1.9460578f-5 -6.854052f-5 4.2229578f-5 -0.00014128443 -9.757619f-5 3.4389694f-5 0.000119865654 -0.00012298222 -7.8954945f-5 -2.8915567f-5 5.9533173f-5 2.519481f-5 4.8866827f-5 -0.00015636756 3.60406f-5 -4.0035815f-5 0.00014140503 0.00016875725 -2.8803375f-5 -6.313255f-5 -5.681532f-5 5.6181623f-5 -1.841565f-5 -8.923271f-5 -8.390456f-5 3.8213962f-5 5.2830426f-5 -2.9212863f-6 0.00012916065 -0.00010364941 -8.4441526f-5; 0.00010339336 0.00017781714 0.00013627746 2.6357045f-5 -0.00010808135 -6.177129f-5 3.6617712f-5 5.4338107f-5 -0.00021452455 -6.394711f-5 1.0330123f-5 -7.9564605f-5 -5.338983f-6 -5.1792693f-5 0.00011016393 -4.44936f-5 5.891852f-5 -7.5776796f-7 4.605146f-5 3.6284798f-6 8.589625f-5 -2.0119602f-5 -9.7660624f-5 4.575942f-5 1.02594995f-5 5.306162f-5 3.0164147f-5 -0.00011281836 9.330938f-5 -6.824483f-6 0.00010446661 0.00013396057; -0.0001951888 5.8384354f-5 -0.00012530261 -1.5879175f-5 2.934209f-5 1.1002939f-5 -6.6203844f-5 -4.360819f-5 0.000111181915 -9.908264f-5 6.734528f-6 1.705266f-5 0.00019367308 2.2095099f-5 -9.271286f-5 -2.1095551f-5 1.7197346f-5 -5.9032376f-5 0.00010082366 0.00022012803 -4.9459522f-5 8.038054f-5 1.3113048f-5 -5.73554f-5 -9.914133f-5 -3.299627f-5 6.887422f-5 -5.0654384f-5 -0.00010898708 -3.8197413f-5 -0.00012262804 -1.4406404f-5; 3.1107324f-5 0.00010170236 1.426584f-5 9.6826465f-5 -0.00031170406 2.2756863f-5 6.5724475f-5 -6.666957f-5 1.3804084f-5 -9.155169f-5 0.00011222639 1.0309875f-5 8.8811095f-5 -4.7896756f-5 -0.000111197885 -4.5646295f-5 -0.00019733357 -5.115883f-5 -0.000120401164 0.00015094364 0.00014136104 -0.00015634643 -0.00013227183 4.6890967f-5 4.7300713f-5 0.00014805992 -7.214264f-5 2.913857f-5 7.2481045f-5 4.2200036f-6 -9.626872f-5 0.00010495324; 1.0661233f-5 6.9693483f-6 0.00021526254 -5.869133f-6 -0.00017352536 -1.41400615f-5 -5.1640523f-5 1.1182987f-5 -0.0001205724 -0.00010635317 3.9127324f-5 -2.5970165f-5 5.983971f-5 0.00015770781 -0.00014540597 -4.7801197f-5 7.4078715f-5 -0.000116320916 -0.00016742683 
0.00023071725 -0.00010765882 0.00013153923 -2.443311f-5 -2.023532f-5 -9.393517f-5 9.2195725f-5 0.00013583958 1.7185928f-5 7.486675f-6 9.722826f-5 8.751207f-5 -7.011925f-5; -0.00023978678 -6.1876206f-5 9.9080644f-5 -0.00014268699 8.1891776f-5 -0.00020162198 -5.257116f-5 -3.541215f-6 -9.830737f-5 2.0778584f-6 -0.00014887418 -2.2630404f-5 6.791644f-5 9.3562165f-5 -0.00012339842 0.0001757439 0.00011024978 -0.00015698405 1.566887f-5 0.00010369461 -4.2116135f-5 0.00015388132 -5.5769633f-5 -4.5952893f-6 -9.38876f-6 2.78696f-5 -1.3529659f-5 3.1242432f-6 -0.00017411486 0.00011940141 8.766942f-5 3.7920858f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[7.895329f-6 -3.925668f-5 5.045617f-5 -7.968984f-5 -0.00013499377 1.2882123f-5 -5.9575228f-5 -7.801501f-5 0.000106755935 -0.0002604102 -7.749079f-5 -7.7674566f-5 1.0697102f-5 -5.8125926f-5 8.9086636f-5 -8.653789f-5 8.915331f-5 -5.6801484f-5 -5.418811f-5 -1.3612775f-5 5.475511f-5 0.00011464535 -6.745991f-5 2.6499602f-5 -9.849866f-6 -7.360977f-5 -3.198457f-5 1.4259968f-5 -0.00010739438 -6.891933f-5 -4.6188936f-5 -9.7742326f-5; -7.106096f-5 0.00013050942 -6.3523716f-5 1.3326511f-5 -5.0962455f-5 4.193433f-5 -8.490918f-6 -0.00017491188 -9.500565f-6 6.1635487f-6 -8.277729f-6 -9.7031276f-5 -2.5881818f-5 -1.7820346f-5 -0.00015816202 -9.700425f-5 -6.2374384f-6 -8.512483f-5 -0.00014334935 -0.00017616172 9.259983f-5 -7.032087f-5 0.000118795215 0.000139581 -0.00018006409 7.828696f-6 2.6457557f-5 5.705327f-5 -6.24685f-5 -9.859035f-5 -2.3041932f-5 9.864874f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most deep learning frameworks, Lux defaults to using Float32; however, in this case we need Float64.
# Convert every parameter array to Float64, stored as a flat ComponentArray.
const params = ComponentArray(f64(ps))
# Parameters are supplied at call time (hence `nothing`); the layer keeps `st` itself.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
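Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected multiplicatively by the two network outputs $\mathcal{N}_1, \mathcal{N}_2$:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $u[1] = \chi$, $u[2] = \phi$, and $p$, $M$, $e$ are the same constants as in the relativistic model.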
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE.

The Newtonian orbital rate for the state `u = (χ, ϕ)` is scaled componentwise by
`1 .+ nn_model([χ], nn_params)`. The constants `(p, M, e)` come from the global
`ode_model_params`, and `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # `nn_model` is a StatefulLuxLayer that already holds `st` (an empty NamedTuple in this
    # example), so only the parameters are passed. In general `st` carries the network state.
    correction = 1 .+ nn_model([first(u)], nn_params)
    # Newtonian rate shared by both equations.
    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# `prob_nn` is read inside `loss`, so it is declared `const` to keep that global access
# type-stable during every loss and gradient evaluation.
const prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    # Overlay the target waveform and the prediction from the untrained neural ODE.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
    # Each legend entry combines a line with its scatter markers.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
    fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
"""
    loss(θ)

Mean-squared error between the target `waveform` and the first (plus-polarisation) output
of `compute_waveform` for the neural ODE integrated with parameters `θ`.
"""
function loss(θ)
    # Integrate on the same fixed-step RK4 grid that generated the data.
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred = first(compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params))
    return mseloss(h₊_pred, waveform)
end
loss (generic function with 1 method)
Warm up the loss function (the first call triggers compilation):
# First call compiles `loss`; the value shown is the loss of the untrained network.
loss(params)
0.0006580560382857572
Now let us define a callback function to store the loss over time
const losses = Float64[]
"""
    callback(state, l)

Optimization callback: record the loss `l` in `losses` and print the iteration counter
`state.iter` alongside it. Returns `false`, so training is never stopped early from here.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
callback (generic function with 1 method)
Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.
# Gradients come from Zygote. The objective ignores its second argument
# (the Optimization.jl hyperparameter slot).
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback=callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [3.37624078383064e-5; 9.377615788245451e-5; 6.397434481188669e-5; -0.00011006930435531318; -0.00014464270498127274; -5.673937994280611e-5; -5.627668724626745e-5; -0.00012045649054907243; 7.008266402408173e-5; 6.066312198525873e-5; -0.00011987695324932118; 0.00021748909784925704; 0.0001494034222557913; 3.1554843644697083e-6; -4.861263732892413e-5; 7.799579179841923e-5; 5.96424843024077e-6; -7.402087067013206e-5; -2.3927208530959386e-5; 0.0002045336586888318; 5.863987826147611e-5; 0.00018471539078759573; -5.4077347158448226e-5; 2.1376543372725575e-5; 1.4348436707224675e-5; -1.754227560011416e-5; 1.9830833480182585e-5; -0.00010148845467476113; 9.802707791088757e-5; -4.774986882690953e-5; -4.392182017904155e-5; -0.0001550568122180735;;], bias = [7.117870665631646e-17, 8.821904558349283e-17, -2.220175956605897e-17, -1.260329517256922e-16, -1.259015395897539e-16, -9.024414196184708e-17, -1.5816051636664314e-17, -2.1840592629026906e-16, 3.2692340749519613e-18, 1.547584345009391e-17, -1.3280014060949561e-17, 2.471441071052053e-16, 2.4208547052080566e-16, 1.9378275016771824e-19, 1.106981834710812e-17, 4.6286869834108276e-17, 5.847209627678565e-18, -9.171663973846917e-17, 1.1644803157364875e-17, 6.485625351265259e-17, 1.4737196710404254e-16, -1.5849390201636644e-16, -7.339000212276461e-18, 3.218211677913231e-17, -1.380572060540901e-18, -2.4432463830136888e-17, 5.321242268596277e-17, -8.892117525311091e-17, 7.875631921909812e-17, -3.706233292912627e-17, -6.097227847579628e-17, -3.687698568447955e-16]), layer_3 = (weight = [-3.167920233713838e-5 -0.00017326813333816058 2.1455388079310638e-5 3.0550958596459885e-6 8.670178856087093e-5 2.112832832758463e-5 8.137645697538424e-5 2.2439022970185588e-5 2.167273548516175e-5 -9.720445753223568e-5 -6.561174025441528e-5 -7.939407744592075e-5 -1.1815659789450176e-5 -0.00012521449328083017 -6.917002181900817e-6 2.2954264451663432e-5 9.199578987286018e-5 
-9.001279916296965e-5 -5.724788145776298e-5 0.00019338668123205957 -8.155550251863276e-5 -0.00010607165252471729 3.47212050397941e-5 0.00016550774368774684 -0.00014083542130570329 -0.00010901917933402051 5.5820601730054006e-5 5.126050443473152e-5 7.777638590746421e-5 -1.4576130823435903e-5 -3.0185639232833685e-5 3.034383250601853e-5; -8.753110306440155e-5 -0.00015944380867164866 8.823509268825622e-5 0.0001474006989052489 7.689266784712325e-5 2.830054144272334e-5 0.0001925907112205376 1.3385685987150442e-5 -0.00019908888644874355 0.0001736175312440817 -5.72796103473318e-5 0.00012173999711069712 0.00010118268836652504 0.00011984348237877483 6.573089548174667e-6 3.7271527970579626e-5 5.530156057494503e-6 8.933756033787246e-5 6.002102089826718e-5 0.00011047328588449861 -0.00019905597001649693 -0.000181932644055788 -8.535351081388316e-5 -6.623777299989246e-5 1.1670574520968399e-5 -0.0001894318571534031 -1.0595880313639041e-5 3.9650238603490334e-5 -6.824787270613297e-5 -7.978300856063384e-5 2.6435622556738027e-5 -1.634701912859821e-5; 7.989982477971371e-5 -3.3791534949558686e-5 -4.461785065126555e-5 8.198573996429474e-5 -8.887532530701481e-5 5.184660622268025e-5 4.917273545521352e-5 2.1225965977065997e-5 6.161555429804198e-5 5.205033667385624e-5 8.68167357135115e-5 5.042347256909656e-5 -0.00015654163653147155 4.337717328104196e-5 -4.9898184076882584e-5 0.00010143088220735153 -3.8654559464193824e-5 0.00010972313102916879 -9.389342958223994e-5 -3.929081649169787e-5 -0.00011096778107224919 -9.675960390944581e-5 -0.00013282543440931442 -5.088710313197215e-5 -4.831993428291745e-5 6.971202165237637e-5 9.548538437462421e-5 -4.247080287081069e-5 -2.2584874020539035e-5 -0.0001419922242325262 0.0001084134513826578 3.4423012405171e-5; -4.0083620865433876e-5 5.846950514304356e-5 -3.457150453406161e-5 -3.684895977964916e-6 7.925601949645937e-7 -8.28996228164313e-5 0.0001285875667009252 7.235251976006181e-5 0.00030015236259267836 2.2837860137850707e-5 4.450118356097016e-5 
0.00011936541411619718 9.21924118075139e-5 -0.00010763443717670125 9.157187448644422e-5 -0.00015461215955375605 -0.00018267391280906705 0.00011015493944349827 -7.390435641801453e-5 -2.6607061249651118e-5 5.914513601518506e-5 6.254035977241632e-5 0.0001343429073814054 -0.00011991223730084015 -9.971242638956527e-6 -1.4898696235001113e-6 5.975909950856594e-5 -5.832506142261635e-5 4.3634380540501694e-5 2.6745163973000585e-5 3.5476307653052716e-5 2.4758109043499327e-5; -0.0001069333848911553 -3.2329684550199513e-5 0.00013549984206722663 0.00016848164151699556 -4.499029463229718e-5 1.6322124244903243e-6 5.308844621570423e-5 -0.00014600227230883006 5.139216947529981e-5 -1.751802491096223e-5 6.622464199128763e-5 -1.0472039869351326e-6 0.00012321575551570328 -6.48424866860944e-5 -8.603369140953156e-5 7.831325994150257e-5 -5.1963306837270704e-5 -9.350973785810608e-5 1.973714817243357e-5 -5.969273123381441e-5 0.00019921086269707988 -0.00015264945621791748 -3.649727303606807e-5 1.3705954110927065e-5 0.00010968411144379266 0.00010686512897766862 7.906972670273588e-5 3.469473805555478e-5 -7.439013651062374e-5 2.1636082160264697e-5 6.394274888837393e-5 -6.669148941480675e-5; 5.964892587597201e-6 5.683954692220906e-5 -1.821305453768737e-5 7.305453869638246e-5 -8.100440503906317e-5 3.500878916247111e-5 0.00013942589681255 2.8664739837478243e-5 -2.48925436985762e-5 -0.00018749566992619764 5.7357940712319246e-5 1.9853231386550946e-5 8.6948547285348e-5 3.366916531240296e-5 4.361028437278624e-5 -0.00011434716474605756 3.374119365480454e-5 -0.00010216075064955243 -4.598366398704574e-5 -2.4166923359692145e-5 0.00019584905636156247 -3.719304481050431e-5 8.752043027788728e-5 5.6787261890803034e-5 -0.00017161515028556174 8.97665402212615e-5 1.931968177677748e-5 0.00021106130839915435 4.522387713085426e-5 -2.7778754194062854e-5 9.955373639831762e-5 0.00039455401816528067; 0.00015191661157204644 -6.0895591672488536e-5 0.0001358160960769936 -4.411600013936423e-5 2.239299908597459e-5 
-6.50233624803183e-5 5.07103266117728e-6 -6.363596105456714e-5 0.00014961441127142944 0.00013880190169929536 -5.056027941845727e-5 -0.00023794293440385317 -6.27344917340386e-5 -8.238047951108955e-5 -2.5597432908584024e-5 5.2246309565633016e-5 2.8488519781553895e-6 3.1366937833624986e-5 4.044480817883124e-5 -0.00011080810429501224 -0.00016883596174393784 5.347763806258984e-5 0.00010104032196595305 -0.00018890897547335522 0.00011324960439723748 -0.00012654109661908055 -4.8133205495292076e-5 -0.00010262953503497524 6.610021307624262e-5 -8.509860992872218e-5 -2.307220631719648e-5 3.3478006003686616e-5; -8.750571341474388e-5 2.9171251847998682e-5 0.00022042099826981534 -0.00019310106781298826 0.00015282385957440025 7.901108129195137e-5 0.0001883917671908924 -0.000160744244018444 -0.00012011786511388369 -0.0002386081809130197 -8.29730610315364e-5 -1.873104101499299e-5 -7.691064583590564e-5 6.211071250041144e-5 0.0001110235703924037 -8.482197644875122e-5 -0.00024238184177770148 0.0002331867241115614 0.0002002248025779906 2.7067878003130964e-5 -7.618553845199552e-5 -1.4875200937546803e-5 -8.716813989346948e-7 -0.00010648710067138869 -3.343032348888792e-5 0.00026645170383167584 0.00011235766815266709 -0.00013275936853128126 -9.59195289363771e-5 5.350707203025162e-7 -8.601916979054748e-5 -2.1400393412786454e-5; -5.291879743883789e-5 3.7268901804121446e-5 -5.533517934229679e-5 3.133569306380453e-5 -6.623526961623505e-5 -2.6162110872906367e-5 1.0649548028663542e-5 0.0001608110968737241 5.043618683043143e-6 -4.234058825379685e-5 -0.00010000956636934243 -9.433466140029294e-5 -5.484743552363316e-5 0.0002943915377619079 -1.1441657181353825e-5 3.6264117523454825e-5 0.00027456603755861135 1.8442503358595877e-5 -7.722547990117615e-5 -1.95399944568024e-5 2.395464333016654e-6 -9.024936399503654e-5 -5.565122875116268e-5 0.0002146945543136781 -0.00012355810860019563 2.9500287280893724e-5 -0.0001773271297787914 4.419368532405604e-5 -6.565621980546593e-5 7.017291399104099e-5 
3.355193520117135e-5 -3.545184301337429e-6; -0.00010798776255480803 -4.983234665498335e-5 6.02983231269512e-5 7.684987159653625e-5 -9.018809995405566e-5 -0.00010038683884570073 -8.031394880756438e-5 -7.609924849613328e-5 0.00011581808886465119 0.00017184920526549561 -5.7875481240001954e-5 2.3185174917229317e-5 3.9535595909449376e-5 -3.254833117037027e-5 4.090553503690524e-5 -0.00012472381839554234 -1.2795204007967277e-5 -4.136756305026345e-5 5.161404636486946e-6 4.176461826637883e-5 -2.5177651501983936e-5 9.350263955048183e-5 -0.00016413878673713373 -5.429511346531011e-5 -4.696505910626543e-5 9.311580598794001e-5 -0.00025460562809396457 -7.55708684541945e-5 0.00015284158090569876 1.2720068674276088e-5 -0.00010227921355774422 -0.00012113024475775608; 2.547834252154092e-5 -1.8926282919988077e-5 1.3522487150925615e-5 7.905332482450391e-5 -1.9931367333905504e-5 -0.0001694761941881605 -2.5421988667499898e-5 0.00023060071713088784 1.703891744809367e-5 -0.00012146769417186605 6.9838200569351335e-6 -8.292087060090903e-6 0.00024172338622419946 -4.425978671961839e-5 -9.114341641280221e-5 0.00013496301276872555 0.00019254133245239595 0.0001278326761702323 -4.69157258895449e-5 -3.905674215375391e-5 -0.00011279537115215664 -8.809641271100405e-5 -5.55430180075799e-5 7.815861486860181e-5 -2.3275943150207398e-5 1.824494380243217e-5 -8.38231482167486e-5 -0.0001356490012231962 -8.543474372445882e-5 -0.0001959680722767243 -1.9594836104357437e-5 0.0001296695935333414; -0.0001155896320473527 1.3171617423347427e-5 2.3640215216425326e-5 -1.3793634128997511e-6 -9.748269120465082e-5 7.718996449838111e-5 -1.2994242284924767e-5 -0.00012293860389619264 -0.00015768808269641621 -7.704230709933951e-5 0.00021275686550651545 -7.256888101110869e-5 -2.860310313409226e-5 -7.27711453568316e-5 -4.810623112135861e-5 -8.694930919447228e-5 0.0001419342948007272 5.4706295044149534e-5 -2.9066892681306617e-5 -6.231051564682982e-5 -0.00015988431595034242 0.00013972211264771535 0.00023841314351883433 
4.878032553812088e-6 2.1148223380429234e-5 2.7470057387823747e-5 -0.00010487534790301084 -2.4921812370314296e-5 2.3576088563993452e-5 -5.8921036601654925e-6 0.00014575088381533214 4.254049737354733e-5; 2.0935172508836868e-5 -0.0001781759041450903 7.154164543831873e-6 -0.00010803800105079519 -6.46732427385496e-5 3.549554642599961e-6 -0.00012206580722434876 -2.521001297011236e-5 4.9967629459358116e-5 2.7440055057891607e-5 -7.563777616908055e-5 -2.0184630918252228e-5 -5.95438163118283e-5 5.604480393500246e-5 5.8248291318672245e-5 -2.40787107004797e-5 -0.00017587173933592446 0.00012281641975547925 4.689366469661941e-5 -0.00012155505682766621 -0.00015881009894386334 6.915255799231075e-5 7.998408691761563e-5 5.607465575814623e-6 -0.00011692513216716039 9.050512890367976e-5 -0.000138828995111033 2.1069617653641155e-5 -8.138613193025579e-5 -0.00010210596373966483 4.1257278647175085e-5 -8.627983695088506e-6; -0.00029144770532802446 5.977396640739001e-5 -2.518617193053704e-5 -8.550536128292861e-6 -5.3852546349627295e-5 -0.00018109396923754733 7.87710988363027e-5 -7.488012206228862e-5 -2.7013066662029392e-5 -2.4304322229719866e-5 -3.129132610117521e-5 -2.0960705020765525e-5 2.1620749160539434e-5 5.7078213197384495e-5 4.623415507303977e-5 -4.782386976947915e-5 -0.00014379055655517567 -0.0001312346984348447 7.107539122735286e-5 -3.530126277539617e-5 -7.08701017145287e-5 8.399767384913701e-5 9.032730398036575e-6 -9.84176630452543e-5 -1.2617814624534607e-5 -9.922188919226544e-5 -0.00013334479890254219 -5.685783698643718e-5 -2.052888966935026e-5 -9.53780662683077e-5 0.00010211821345656534 -0.00024132604894180594; 0.00013585862448637457 -3.380793954746327e-5 0.00014966328431584744 2.4850177000311845e-6 -0.00011382916149486246 0.00017412274006021143 -0.00012279787064797305 0.00010608649004681099 0.00011917678180581704 -2.2116317853560722e-5 -2.0153819281029694e-5 6.255851017739228e-5 0.00015060425481024637 5.065173295990726e-5 -0.00013491162703174172 -5.758452484186077e-5 
-6.119225749314605e-5 7.893332199857343e-5 4.810223741189699e-5 7.814244723375877e-5 0.0001724828265254657 5.112011909331733e-5 5.645253382355861e-5 7.46583185504483e-5 -0.00014229734974249118 0.00016678092047770648 -5.662737989366851e-5 -0.0001210041215413959 1.3959197855379363e-5 -0.00019529296573053112 -1.3667283968587622e-5 2.1689521476256817e-6; 0.00010682048223750515 -0.0001006683793959595 -6.977883976580725e-5 0.00011087110970472771 4.044033907827436e-6 6.853521545488238e-5 -0.00015421757206715792 5.42157488823947e-5 3.684536414291561e-5 -0.0001102394214242153 2.6635280214374377e-5 -0.0001277195676222829 6.349090318033329e-5 5.984905605852497e-6 -0.00011745223827822285 -3.328366592106557e-5 0.00023612364230282454 -0.00016214219585456102 0.00016155961412075756 -0.00017134124727383388 -7.523443282045574e-5 -8.746305474707993e-5 -7.328135502778787e-6 1.6101080789454403e-6 -2.3032858457757726e-5 0.00017523279739614425 3.66583283764854e-5 7.939481502469482e-5 -5.3760677197008925e-5 0.0001539221723803863 1.4726949579905155e-5 -4.656471509881458e-5; -6.0404998675362947e-5 5.342235233970837e-5 7.655519385160801e-6 2.1726423784148773e-5 -0.0001728907431418761 6.745634307721738e-5 -4.687456968836385e-5 -0.00012182771608362225 -6.52851071384702e-5 -4.939573562767068e-7 0.00019511734133690324 -3.807452814661509e-5 -0.00015769880557793333 -9.364207509802973e-5 -7.609231981346391e-5 5.662543624626031e-5 4.907214823164662e-5 -0.00023589944468708297 -2.114050077521486e-5 -1.6692254957248112e-5 0.0001632605941091678 7.958812200803807e-5 -4.673639561529202e-5 -0.00017696497381562628 -2.5156203645889778e-5 -4.754934199750702e-5 -0.00011682353071533411 -1.6864096705232482e-5 -0.00011750677951508174 2.1699057088559334e-5 -9.827492469353664e-5 -9.332511255645108e-5; -3.0508539835630143e-5 -8.34612131218301e-5 0.00011265555959130978 8.56208441330243e-5 -5.724564316164446e-5 0.0001266515698300802 -6.502668683946719e-5 0.00011166932536457952 7.773035186278971e-5 2.772822127031506e-5 
-0.00013862255427351863 -0.00017658187998290052 1.9176400841621377e-5 3.0980941759714746e-5 4.568892627890572e-5 -6.459869196753899e-6 -0.00010869329552928557 -2.289379180705772e-5 -8.853761236945741e-5 7.884927772852744e-5 -9.218390032014464e-5 -1.7992675835325363e-5 -0.00012120587625371877 0.00014078603426629103 0.00031047932289244033 -1.2479803713852199e-5 -7.152744210156253e-5 7.245296156202751e-5 -0.000113480191699403 -6.782718837325112e-5 7.36179951711686e-5 4.49037743624656e-6; -0.00017959261003720306 3.5949994527484465e-5 -6.951561771116944e-6 -3.9214551178291004e-5 6.731050689400037e-5 -9.705383399572931e-5 0.00012202703237694913 -4.823980543601332e-5 9.203649551768977e-5 4.724228074469819e-5 -6.592447333376654e-6 -0.00019555379094395533 0.00011294473097354941 4.430822173294902e-5 -2.1830762305895236e-5 6.416529959228987e-5 -4.030589388871275e-5 4.2001648581802766e-5 4.462411107074749e-5 9.325499286337192e-5 -0.00011736275609886985 -9.387037698484434e-5 -7.842930335622983e-5 5.0307470702550394e-5 -3.253190061381731e-5 -7.504345374896144e-6 0.00011726700631830191 0.00012003430038153274 2.567848150779588e-5 0.00011481663114086968 -4.3208971872941876e-5 -9.572720044662388e-6; -4.465625370704655e-5 0.00012467893047522416 -0.00017450681680549116 0.0001370379884939969 0.00022906253961644342 5.972807575397972e-5 -4.688958343086915e-5 0.00013774275230042117 -9.28174013444197e-6 1.7478283367551118e-5 -7.081673011618249e-5 5.0880416707669456e-5 0.00010610523695157072 2.4113316427417537e-5 -2.8727376233121695e-5 -4.07691533872044e-5 8.049308762218753e-5 -5.017683197712752e-5 -9.510896980265222e-5 -2.978418542973234e-5 3.2421927058918666e-5 6.490773903003676e-5 -4.494321747535052e-5 2.1775704041090164e-5 -5.265812997441794e-5 5.692046924758644e-5 7.062846798015697e-5 8.985963886169336e-5 -0.00012179523965079695 -2.521912956928025e-5 2.5511711825117293e-5 -0.00021756290716786763; -1.8061316714671872e-5 0.00011100990946876173 -0.00012345115987111974 
-6.778737913365786e-7 -6.290467148642085e-5 2.4940480451978964e-5 0.00011002651286550734 6.015612439956214e-5 1.9793981293960162e-6 9.334537595284052e-5 -4.274093384150704e-5 -5.786356427122792e-5 0.00018662967910063277 -4.978367882424706e-5 -3.6831088201077484e-5 -3.819336574518487e-5 0.00015881271816331217 5.6645696767435175e-5 -3.497261217555761e-5 4.009482859985582e-5 9.963242178789245e-5 -7.874177548798258e-5 9.905477623694536e-5 6.701158987132113e-5 -0.0001673517997375945 0.0002087110463064839 0.00011272985853727195 -3.432356765255201e-5 3.56942493992971e-5 -5.514091972807399e-6 0.0001066886964138324 5.969501058575345e-5; -0.00012168180696454949 -9.576046581633932e-5 -0.00012283676337239256 1.857482343774895e-5 8.67696420334119e-5 -0.00010718403213678841 3.072764807592162e-5 -1.7777413370887331e-6 2.4521267144960726e-5 4.6490559137678266e-5 -2.0658324180028302e-5 -2.0303150128059586e-5 -6.641601216533803e-5 -5.2496534086747166e-5 -3.967694433078207e-5 -0.00011529263578842186 -3.3577108885144865e-5 -7.832905398237352e-5 -7.086346308489487e-5 4.709403070020265e-5 -7.746677296956038e-5 -1.2008171003156233e-5 0.0001243710415598657 0.00020168545447966848 -8.1307787373878e-5 5.630721930509259e-5 9.308614020194578e-6 -4.195888108944034e-5 5.5003058472438475e-5 -5.256564369406986e-6 3.803061040678927e-5 -0.0001317041766435979; -2.6853392156556196e-5 0.0001778352298776036 4.585227277733768e-5 -3.638878563047418e-5 -0.00013217089956137746 -3.197944253894174e-5 0.0001473708244508444 -8.578339559003587e-5 1.4444219303768751e-5 -7.332620673092659e-5 -9.926483555441608e-5 -2.3839475871981956e-5 -7.674381133378784e-5 2.9350282674055795e-5 7.101795843358421e-5 -0.0003128696416036827 4.4336025056178944e-5 3.826696691125172e-5 2.1601065853826382e-5 -2.4664538353440067e-6 -3.0519181492637234e-5 -0.00012400995267065714 0.00016717539900004238 -4.8536458890669446e-5 -2.9528721753525173e-5 -0.0001950759623663207 1.5152336417156304e-5 4.039839684869015e-5 -3.76218167014998e-5 
-1.4058037705910073e-5 -8.219896055503506e-5 8.65830798635706e-5; -0.00011823268514377249 -0.00010530092194927133 -2.994329415628878e-5 -0.00011595507030393216 9.647387551125823e-5 0.00018556366185406968 7.303808879555012e-5 -7.8950612168007e-5 -7.223363727349053e-5 5.786644023941639e-5 7.421779538383787e-6 -7.697299233655777e-5 0.0001277995724205927 -5.623957631428905e-5 0.00012689001951115953 -0.00014173001190898203 4.6928664143018496e-5 -7.25156533906163e-5 2.7174708430942213e-5 -8.035565779096761e-5 0.00010897237725770842 -7.8665481941023e-5 -4.5106140521571655e-6 9.637696703179557e-5 -2.451531645479263e-7 -1.03206212759292e-5 -4.534982153814466e-5 7.328244455606463e-5 0.0001222505924499666 2.114962617383153e-5 2.6855241140986813e-5 -8.261671982723545e-5; -5.738495502437144e-5 9.115381263938552e-6 8.353284892228213e-5 -0.0001172636596572222 -3.829978350829193e-5 -1.7245772422662566e-5 0.00011099891952772283 -1.0224146816362708e-5 0.0001076825234952651 -0.00014718438370318426 -4.1803819211672795e-5 -0.0001617332866850412 2.7552156726716542e-5 5.119630322160371e-5 0.00013696261402519903 -0.0001644298293364882 2.2156481182991417e-5 -6.74784508270867e-5 -7.151894288535916e-5 -3.682472134518328e-5 -0.00010419993969544145 -5.924873158496782e-5 1.0270158020861375e-6 -5.580279076521054e-5 -7.104020670221922e-5 3.6617594279323016e-5 -0.0001760751886728943 -7.153190864181476e-5 -2.904661351629955e-5 -0.0001670825561711331 -0.00010573393711519932 5.362083055964293e-5; 0.00011419338053378704 -8.918858137626435e-5 -3.297237552765227e-5 -2.377277477357656e-5 -0.00011117865714190339 -0.00013393353064025253 3.2434059641626986e-5 -0.00010416640205953304 -9.933875325082242e-5 -0.00012339007142877622 0.0002607279872899684 -6.225645692032578e-5 0.00011108613917891176 -0.0002123565521179008 1.5760589769296383e-5 -0.00011595177264830428 4.99171380506122e-5 2.6203597109952205e-5 -0.0001467554559786112 -4.2644287794278995e-5 -3.751877947074785e-5 -3.0088049589166e-6 
3.134059598955815e-5 0.00013063590241649587 5.104205930542594e-5 0.00011225200952317033 -9.56693277549238e-5 0.00016561294406706285 2.6052126224341594e-5 -1.915043166086543e-5 1.3399930146412936e-7 2.0404622511666475e-5; -4.657202135263436e-5 -1.9461908783568966e-5 -6.854184812823604e-5 4.222824736075532e-5 -0.0001412857591333196 -9.757751741404046e-5 3.438836301370187e-5 0.00011986432309072627 -0.00012298355341613558 -7.895627550738246e-5 -2.8916897468886648e-5 5.9531842447330405e-5 2.5193478899160783e-5 4.886549594970877e-5 -0.00015636889202710617 3.6039270520403765e-5 -4.003714548730765e-5 0.00014140370421387017 0.00016875592063456322 -2.8804705840442698e-5 -6.313388349213029e-5 -5.681664970892392e-5 5.6180292625637976e-5 -1.8416981225758687e-5 -8.923403857466299e-5 -8.39058911935845e-5 3.821263180156712e-5 5.282909568791255e-5 -2.922616891964526e-6 0.0001291593152870945 -0.0001036507374067271 -8.444285686550082e-5; 0.00010339495814180756 0.0001778187422408571 0.00013627906330869198 2.6358645054937227e-5 -0.00010807975310825102 -6.17696871537837e-5 3.661931144081702e-5 5.433970612132039e-5 -0.0002145229530193974 -6.39455113466022e-5 1.0331722233170269e-5 -7.956300511950444e-5 -5.337383545049373e-6 -5.17910932737042e-5 0.00011016552774010914 -4.449200175228968e-5 5.892011789031206e-5 -7.561684080176898e-7 4.6053059786873316e-5 3.6300793729484418e-6 8.589785168653986e-5 -2.0118002130119665e-5 -9.765902474980142e-5 4.576102103813409e-5 1.0261099060090362e-5 5.306321939345314e-5 3.0165746220819193e-5 -0.00011281675803700191 9.331098292962022e-5 -6.822883524064731e-6 0.00010446821282160248 0.00013396216557368183; -0.00019518998732463748 5.8383163660541275e-5 -0.0001253038032599649 -1.5880365653069218e-5 2.9340900538154152e-5 1.100174841768598e-5 -6.620503398961597e-5 -4.360938027173978e-5 0.00011118072444434136 -9.908382770744342e-5 6.7333378966548105e-6 1.7051469795779634e-5 0.0001936718938951528 2.2093908511224584e-5 -9.271404695836137e-5 -2.1096741592270356e-5 
1.7196155850897935e-5 -5.903356640272894e-5 0.00010082246745797493 0.0002201268392228745 -4.946071266102804e-5 8.03793475926116e-5 1.311185798032222e-5 -5.735658913989868e-5 -9.914252285751841e-5 -3.299746157507735e-5 6.887303180620687e-5 -5.065557405840536e-5 -0.00010898826866549398 -3.819860347802065e-5 -0.00012262923400056808 -1.4407594507137282e-5; 3.110662851962061e-5 0.00010170166547889796 1.4265144962548363e-5 9.682576979524508e-5 -0.0003117047565335848 2.275616839891148e-5 6.572377947871421e-5 -6.667026558876864e-5 1.3803389045490225e-5 -9.155238171094238e-5 0.00011222569788354551 1.0309179522885653e-5 8.881040026448022e-5 -4.789745093309544e-5 -0.00011119858049075291 -4.5646989967069525e-5 -0.0001973342698345449 -5.1159524666006735e-5 -0.00012040185924140017 0.0001509429455456863 0.00014136034929673042 -0.0001563471271235906 -0.00013227252588127896 4.6890272422598396e-5 4.730001797564108e-5 0.0001480592216087018 -7.214333549360962e-5 2.913787444252531e-5 7.248035013456823e-5 4.219308501189245e-6 -9.626941413604837e-5 0.00010495254151304576; 1.0661525189375096e-5 6.969640176561218e-6 0.00021526283027519126 -5.868841148600513e-6 -0.0001735250690075713 -1.4139769595618034e-5 -5.164023158109155e-5 1.1183278652918698e-5 -0.00012057210468306563 -0.00010635287534064957 3.912761614223494e-5 -2.5969872771802547e-5 5.984000100901757e-5 0.0001577080992461058 -0.0001454056756160348 -4.7800905198942824e-5 7.407900640423101e-5 -0.00011632062440590365 -0.00016742653595834352 0.00023071754632432234 -0.00010765853138259953 0.0001315395210581675 -2.4432817630855012e-5 -2.0235028463917196e-5 -9.393487478924294e-5 9.219601717226827e-5 0.00013583987394262726 1.7186220171009768e-5 7.486966674645142e-6 9.722854967143986e-5 8.75123595647388e-5 -7.011895566798355e-5; -0.00023978806106213574 -6.187748780546478e-5 9.90793620624769e-5 -0.00014268826756931471 8.189049427558483e-5 -0.00020162326230973012 -5.257244023567074e-5 -3.5424967242520363e-6 -9.83086514130973e-5 
2.07657663879348e-6 -0.00014887546543802596 -2.26316854783892e-5 6.791515503239693e-5 9.356088290616728e-5 -0.00012339970393411854 0.00017574262043155842 0.00011024849542111017 -0.00015698533510629 1.566758752262102e-5 0.0001036933305479936 -4.211741665101606e-5 0.00015388004127120044 -5.57709148230778e-5 -4.596571070287676e-6 -9.39004197760386e-6 2.7868318884720663e-5 -1.3530940431956063e-5 3.122961434496144e-6 -0.00017411614075142623 0.00011940012942461251 8.766813576614422e-5 3.7919575934512023e-5], bias = [-8.506761811410372e-10, 1.731574512539204e-10, -3.3978489222454153e-10, 2.0561629677160185e-9, 1.0076425758105899e-9, 3.458588707346979e-9, -1.5708103142163625e-9, -6.742960481078557e-10, 5.301336160730268e-10, -2.3215639213336712e-9, -3.1117552334398165e-10, -4.0062005957119884e-10, -3.1272038892844717e-9, -4.886621171162243e-9, 2.17674498608252e-9, -1.2455733133337248e-11, -4.065822720002597e-9, 1.5434512299056505e-10, 6.230051937684992e-10, 9.875258798491752e-10, 3.1270785935189968e-9, -2.024213533926109e-9, -1.95079461554098e-9, 9.813560925909336e-11, -4.04298471116304e-9, -1.0643707829089821e-9, -1.3305964368420732e-9, 1.599555917762126e-9, -1.1902456112175505e-9, -6.950738396617751e-10, 2.9187098767668754e-10, -1.2817869722193323e-9]), layer_4 = (weight = [-0.0006554016998485922 -0.000702553726828611 -0.0006128408762509192 -0.0007429867703139176 -0.0007982907861590612 -0.0006504146067971043 -0.0007228722089322738 -0.0007413120432565619 -0.0005565411054862673 -0.0009237070795225915 -0.0007407878345329968 -0.0007409716089891216 -0.0006525996881839295 -0.0007214223278821621 -0.0005742102929645153 -0.0007498349374932685 -0.0005741433108816683 -0.0007200985310096784 -0.0007174851465327 -0.0006769097970452216 -0.0006085416800648222 -0.0005486515940217652 -0.0007307568491233101 -0.0006367974459926958 -0.0006731464893434357 -0.0007369067885951262 -0.0006952815712581744 -0.0006490370112956622 -0.0007706913875310307 -0.000732216367783692 -0.0007094859812871634 
-0.0007610393278592855; 0.00018925534560954754 0.0003908257321338698 0.00019679259638699192 0.0002736427790489343 0.00020935384705780157 0.0003022505193287231 0.0002518253690979293 8.54044265539941e-5 0.00025081574512060887 0.00026647980001672016 0.0002520385834327286 0.00016328503571406362 0.00023443439429129533 0.00024249571407389877 0.00010215425036905607 0.00016331206031074984 0.00025407870892524525 0.0001751914796797077 0.0001169669638703935 8.415458647740514e-5 0.0003529160443410187 0.00018999540123540505 0.00037911148674237303 0.0003998973089968661 8.02520578996219e-5 0.00026814499683858336 0.000286773851920574 0.00031736955494657867 0.00019784779995368429 0.00016172595468562904 0.00023727438045456663 0.00035896503824177085], bias = [-0.0006632970477686623, 0.00026031631320577487]))
Visualizing the Results
Let us now plot the loss over time
begin
    # Loss history recorded by `callback`, one entry per callback invocation.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, 1:length(losses), losses;
        marker=:circle, strokewidth=2, markersize=12)
    fig
end
Finally, let us visualize the results
# Re-simulate with the optimized parameters `res.u` and compute the resulting waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); p=res.u, u0=u0, dt=dt, saveat=tsteps,
    adaptive=false))
waveform_nn_trained = first(compute_waveform(dt_data, soln_nn, mass_ratio,
    ode_model_params))
begin
    # Overlay the data, the untrained prediction, and the trained prediction.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
# Also report GPU toolchain versions, but only when MLDataDevices and the
# corresponding backend are loaded and the device is functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.