Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning.
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

# Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector.
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) `path` into the paths of the two bodies about their centre of
mass. Body 1 is `path` scaled by `m₂ / (m₁ + m₂)` and body 2 is `path` scaled by
`-m₁ / (m₁ + m₂)`. Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in `(χ, ϕ)` variables into Cartesian orbit coordinates.

`soln` must have 2 or 4 rows:
- 2 rows: `(χ, ϕ)`; `model_params` must then be a length-3 collection `(p, M, e)`.
- 4 rows: `(χ, ϕ, p, e)`, with `p` and `e` given per sample.

The radius is `r = p / (1 + e cos χ)`. The result is a `2 × N` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.

Throws `ArgumentError` if `soln` does not have 2 or 4 rows, or if `soln` has 2 rows and
`model_params` is `nothing` or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    nrows ∈ (2, 4) || throw(ArgumentError("size(soln,1) must be either 2 or 4"))
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` first: `length(nothing)` would raise an opaque MethodError.
        has_params = model_params !== nothing && length(model_params) == 3
        has_params ||
            throw(ArgumentError("model_params must have length 3 when size(soln,1) = 2"))
        p, _, e = model_params  # M is not needed to map (χ, ϕ) to Cartesian coordinates
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    orbit = vcat(x', y')
    return orbit
end

# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly sampled signal `v` (sample spacing `dt`).

Interior points use the central difference `(v[k+1] - v[k-1]) / 2`. The two endpoints use
second-order one-sided stencils. Needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    # Forward one-sided stencil (-3/2, 2, -1/2) at the first sample.
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Central differences for every interior sample.
    interior = [(v[k + 1] - v[k - 1]) / 2 for k in 2:(length(v) - 1)]
    # Backward one-sided stencil (3/2, -2, 1/2) at the last sample.
    tail = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(head, interior, tail) ./ dt
end

# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly sampled signal `v` (sample spacing `dt`).

Interior points use the central `(1, -2, 1)` stencil. Both endpoints use the second-order
one-sided `(2, -5, 4, -1)` stencil. Needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    # Shifted copies: previous, current and next sample for each interior point.
    prev, curr, next = v[1:(end - 2)], v[2:(end - 1)], v[3:end]
    interior = prev .- 2 * curr .+ next
    return vcat(left_edge, interior, right_edge) / dt^2
end

# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the mass-weighted trace-free moment tensor of a planar orbit.

`orbit` is a matrix whose first two rows are the `x` and `y` coordinates over time.
`component` is an index pair `(i, j)` with `i, j ∈ (1, 2)`:
- Diagonal components give `mass .* (I_ii .- (I_xx .+ I_yy) ./ 3)`.
- Off-diagonal components (`(1, 2)` or `(2, 1)`) give `mass .* x .* y`.

Throws `ArgumentError` for any other `component`.
"""
function orbit2tensor(orbit, component, mass=1)
    i, j = component[1], component[2]
    # Only the in-plane components are modelled. Reject anything else instead of silently
    # returning the xy component, as any unrecognised index pair previously did.
    (i in (1, 2) && j in (1, 2)) ||
        throw(ArgumentError("component must have indices in (1, 2), got $(component)"))
    x = orbit[1, :]
    y = orbit[2, :]
    if i == j
        Ixx = x .^ 2
        Iyy = y .^ 2
        trace = Ixx .+ Iyy
        # Remove one third of the (planar) trace to make the tensor trace-free.
        tmp = (i == 1 ? Ixx : Iyy) .- trace ./ 3
    else
        tmp = x .* y
    end
    return mass .* tmp
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the finite-difference second time derivative (step `dt`) of one
mass-weighted trace-free moment-tensor component of `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    I_component = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(I_component, dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Quadrupole components of a single body's orbit, returned as `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    h11, h22, h12 = quad((1, 1)), quad((2, 2)), quad((1, 2))
    return h11, h12, h22
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

(2,2)-mode strain `(h₊, hₓ)` of a single unit-mass body following `orbit`.
The numeric constants are built in `T`, the type of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    c = sqrt(T(π) / 5)
    hplus = c * (h11 - h22)
    hcross = -c * (T(2) * h12)
    return hplus, hcross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum of the per-body quadrupole components of two bodies, returned as `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    a11, a12, a22 = h_22_quadrupole(dt, orbit1, mass1)
    b11, b12, b22 = h_22_quadrupole(dt, orbit2, mass2)
    return a11 + b11, a12 + b12, a22 + b22
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

(2,2)-mode strain `(h₊, hₓ)` from the orbits of two bodies (BH 1 with `mass1`, BH 2 with
`mass2`). The masses must sum to one. The numeric constants are built in `T`, the type of
`dt`.

Throws `ArgumentError` if `mass1 + mass2` differs from 1 by `max(1e-12, 8 * eps)` or more,
where `eps` is the machine epsilon of the floating-point type of `mass1 + mass2`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    total_mass = mass1 + mass2
    # A fixed 1e-12 tolerance rejects valid single-precision masses: in Float32,
    # m₂ = 1/(1+q) and m₁ = q*m₂ only sum to 1 within ~1e-7. Scale the tolerance with the
    # type's eps, but never go below the original 1e-12 (so Float64 behaviour is unchanged).
    tol = max(1e-12, 8 * eps(float(typeof(total_mass))))
    abs(total_mass - 1.0) < tol || throw(ArgumentError("Masses do not sum to unity"))
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` for the ODE solution `soln`.

`soln` is first converted to a Cartesian orbit with `soln2orbit(soln, model_params)`.
- `mass_ratio == 0`: use the single-body (test-particle) strain.
- `mass_ratio = q > 0`: split the orbit into two bodies with masses `q/(1+q)` and
  `1/(1+q)`.

`dt` is the spacing of the samples in `soln` and sets the numeric type `T`.

Throws `ArgumentError` unless `0 ≤ mass_ratio ≤ 1`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # Explicit exceptions instead of `@assert`, which is meant for internal invariants and
    # may be disabled. The messages are unchanged.
    mass_ratio ≤ 1 || throw(ArgumentError("mass_ratio must be <= 1"))
    mass_ratio ≥ 0 || throw(ArgumentError("mass_ratio must be non-negative"))
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Split a unit total mass: m₂ = 1/(1+q), m₁ = q/(1+q).
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        # mass_ratio == 0: single-body (test-particle) strain.
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

# Simulating the True Model
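`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses the state

$$u[1] = \chi, \qquad u[2] = \phi,$$

and evolves it according to

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^{2}\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2}\sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constant model parameters.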
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE. Given the state `u = (χ, ϕ)` and the constant
parameters `(p, M, e)`, return `[χ̇, ϕ̇]`. The right-hand side does not depend on `t`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)
    # Factors shared by both rates.
    shared = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    norm_factor = sqrt((p - 2)^2 - 4 * e^2)
    χ_rate = shared * sqrt(p - 6 - 2 * ecosχ) / (M * p^2 * norm_factor)
    ϕ_rate = shared / (M * p^(3 / 2) * norm_factor)
    return [χ_rate, ϕ_rate]
end
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions
datasize = 250
# Keep the time grid in Float64. `dt_data` is what `compute_waveform` receives, so its
# type sets the precision of the strain constants and of the finite-difference step.
# The state `u0` and the parameters are Float64 as well.
tspan = (0.0, 6.0e4) # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1] # spacing of the saved samples
dt = 100.0 # fixed integrator step (used with adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = solve(prob, RK4(); saveat=tsteps, dt, adaptive=false) |> Array
waveform = compute_waveform(dt_data, soln, mass_ratio, ode_model_params)[1]
begin
    # Plot the first (h₊) component of the ground-truth strain: a line plus the samples.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s = scatter!(ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end

# Defining a Neural Network Model
Next, we define the neural network model, which takes one input and has two outputs; in `ODE_model` below, that input is the first state component χ (not the time). We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.
Using globals is generally discouraged, but in case you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl instead.
const nn = let small_init = (; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32)
    # The first layer applies cos to the input. It is followed by two 32-unit cos layers
    # and a linear 2-output layer. Every weight uses a narrow truncated normal and every
    # bias starts at zero.
    Chain(Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; small_init...),
        Dense(32 => 32, cos; small_init...),
        Dense(32 => 2; small_init...))
end
# Use an explicitly seeded RNG so the initial parameters, and hence the whole training run,
# are reproducible. The unseeded `Random.default_rng()` gives different weights on every run.
ps, st = Lux.setup(Random.Xoshiro(0), nn)

# Like most DL frameworks, Lux defaults to using Float32; however, in this case we need
# Float64.
# Convert the Float32 parameters to Float64 and flatten them into a ComponentArray.
const params = ComponentArray(f64(ps))
# The wrapper holds no parameters (`nothing`); they are passed in at call time.
# The wrapped Chain has 1_186 parameters (64 + 1_056 + 66) and no states.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)

# Now we define a system of ODEs that describes the motion of a point-like particle with
# Newtonian physics, using
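$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $\mathcal{N}$ is the neural network, and $p$, $M$, and $e$ are constants.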
"""
    ODE_model(u, nn_params, t)

Newtonian orbit ODE with a neural-network correction.

For the state `u = (χ, ϕ)`, the Newtonian rate `(1 + e cos χ)^2 / (M p^(3/2))` is
multiplied by each entry of `1 .+ nn_model([χ], nn_params)`. Here `p, M, e` are taken from
`ode_model_params`. Returns `[χ̇, ϕ̇]`; `t` is not used.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # In this example the network state `st` is an empty NamedTuple, so it can safely be
    # ignored. In general, `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * p^(3 / 2))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)[1]
begin
    # Overlay the untrained network's waveform on the ground-truth data.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        alpha=0.5, marker=:circle, markersize=12, strokewidth=2)
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        alpha=0.5, marker=:circle, markersize=12, strokewidth=2)
    axislegend(ax, [[l1, s1], [l2, s2]], ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb)
    fig
end

# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the first strain component predicted with network parameters
`θ` and the ground-truth `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    pred_waveform, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(pred_waveform, waveform)
end

# Warmup the loss function
loss(params)

# Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(θ, l)

Optimization callback. Appends the loss `l` to `losses`, then prints the iteration number
(`θ.iter`, from the optimizer state) and the loss. Always returns `false`.
"""
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end

# Training the Neural Network
Training uses the BFGS optimizer. This works well here because the Newtonian model already provides a very good initial guess.
# Gradients of `loss` (which includes the ODE solve) are taken with Zygote.
adtype = Optimization.AutoZygote()
# Optimization.jl objectives take `(θ, p)`; the second (hyperparameter) slot is unused.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and a backtracking line search; `callback` logs progress.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback)
# retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [3.1807685445497415e-5; -7.478387124124551e-5; -0.00013932163710705632; -0.00019942939979929162; -5.125871030025476e-5; -2.9746048312466074e-5; -7.506302063117644e-5; -8.228870865421423e-5; 2.6307439838969432e-5; 2.3313152269095585e-5; -7.98030960139233e-5; -7.928155537220577e-5; -3.0228102332304186e-5; -7.013693539184675e-5; 0.00017441720410700303; -6.931676034574505e-5; -4.95325948576758e-5; -0.0001096411433535617; 0.0001066019831340761; -0.0002734832523855284; -0.00016516700270572404; 1.533169597677929e-5; 1.4763273611621293e-5; -0.00010183243284685091; 7.273099981832357e-6; -6.982019112904687e-5; 0.00019025181245500725; 0.00016863248310963747; 1.9166856873163088e-5; 2.3672128008938745e-5; -0.0001529027795184814; -0.0001396689185640479;;], bias = [3.329394530334202e-17, -1.453514617057689e-16, -2.0146539053931225e-18, -2.8421226495981545e-16, -6.715480785537107e-17, -5.692528622948356e-17, -3.168355674107845e-17, -1.780804440990402e-17, 3.438388436917298e-17, 1.7040449386025725e-17, -1.6478723316828804e-16, -9.166743863825677e-17, -2.7667852818261396e-17, -8.83358589014404e-17, 1.6836461941605557e-16, -9.124579678300221e-17, -6.608454110309458e-17, -9.971297286811983e-17, 1.0042355287436589e-16, -1.0642760424226255e-16, -1.4224829464157743e-16, 4.739557332838755e-18, -1.6046827595393604e-17, -1.6582400867889878e-16, -2.2168607470764203e-18, -7.004892174692208e-17, 2.5442746462961623e-16, 2.0642626549004426e-16, 3.812054751673557e-17, -2.94277607487978e-17, -1.5967879649778117e-16, 6.869457264690211e-17]), layer_3 = (weight = [-9.418567361623923e-6 1.6964028572057293e-5 2.4875259140477473e-5 -9.983920729711034e-5 2.332818500590509e-5 0.0001742632872891618 -8.205966271916506e-5 0.00014668845566930108 6.173243188541446e-5 -2.1661478359507646e-5 -0.00019903155164783945 0.00011555725965890908 -0.0001797567395725282 -8.918383658155878e-5 -3.698658231870271e-5 7.087809960367083e-5 
-2.2030322116822894e-5 3.648705994137246e-5 -6.943530325862261e-6 1.815608690674329e-5 -2.99759734217393e-5 6.826670748423865e-5 7.378383331243028e-5 -1.7581283656004128e-6 0.00022217746179094622 -4.549759372011835e-5 -7.137757468727612e-5 4.835105452811412e-5 0.00010871179135292843 9.797122101102229e-6 4.566907470588807e-5 -0.0001233360392971952; 8.039712799778023e-5 0.00015739681209502688 -5.10989947785125e-5 -7.299892506728914e-5 -5.035350743732555e-5 4.2668704803286085e-5 0.00014580808881269714 8.179642561422637e-5 9.029342715926275e-5 -3.182300399807481e-5 -4.398922003624065e-5 -1.1045291036356156e-5 0.0001249663119536806 -0.0001786017154942274 2.224303996023426e-5 -7.30699165857405e-5 -0.0001534191168745063 -9.078136547636121e-5 3.2522050854238794e-5 -0.00013883422700621177 3.162068339712535e-5 -0.0001494933157361133 -0.00013814998140027952 0.00012135856475047842 -6.846851367904085e-5 7.807964818617251e-5 -6.513696910281801e-5 0.00013601635243809429 -4.985042711855782e-6 -8.993957568823476e-6 -6.29846389771395e-5 3.7513885284497915e-5; -0.00010560809535201571 0.00011796009418793001 -3.3381428535741876e-5 -0.00020257425133133998 -0.0001587850936874184 -7.311415471865762e-6 -0.00013376952195068843 4.405804354353765e-5 -5.7334198816695e-6 -0.00025403807451477184 3.265690172445925e-5 0.00011062290211446548 -2.8667570247141847e-5 -3.6538688461240445e-5 7.308722631312396e-5 0.00012937008357146747 0.00018241139257332357 0.00039527886035391923 0.00020417618286335708 -5.896365548945486e-5 -0.00018904168321280258 -4.439358674682592e-5 6.236271944834004e-5 -6.350491533050428e-5 -9.037159982364535e-5 0.00014755460795020677 6.19148478773919e-5 0.00020010790392295918 -8.995858736563392e-5 2.3628893516781605e-5 4.492487202985773e-5 -0.00018283073661843607; -9.776280394469069e-5 7.571337873635086e-5 -0.00024980787175188916 -8.394689548475388e-5 -5.318176017787961e-5 2.0645466337398117e-5 2.9031698703716677e-5 4.626161926246422e-5 4.829155686714082e-5 0.0001783923183200635 
8.745161567169428e-5 0.00014963338280615582 -4.834655164920404e-5 8.725744281788236e-6 -7.922574486763635e-5 1.6249365488716775e-5 3.442394545238747e-5 -2.8978440251883966e-6 -6.150546296446524e-5 -7.497242377320903e-5 0.00015170915530581626 -9.245931478396652e-5 0.00012715547591035594 7.625753305439922e-5 -1.0287327951275575e-5 -2.4067817959052153e-5 -0.00016119981928675996 6.624214101868313e-5 0.00016434210947302507 -0.00018287645407433288 9.726905072859245e-5 3.333331214782638e-5; 3.043668407742096e-5 -7.968011182978287e-6 4.1270515843332516e-5 0.00015086120406890155 6.136669985978004e-5 -7.497793553368952e-5 0.00011393454774776378 -6.889850914916183e-5 -0.00010179905070429488 -4.275169335044734e-5 -3.1295821335645995e-6 -0.00013250274054911118 3.573616782347838e-5 1.97929185142735e-5 0.0001152484183458464 7.112600729130159e-5 7.657551764538595e-5 -0.000247490490388316 -4.744450047468249e-5 -0.00018119522423692267 -6.449737697580267e-5 0.00017699961613659018 -1.5405963445965012e-5 1.2345461122152531e-5 1.4915853220436078e-5 0.00012619508202448435 0.0001714774553457095 -1.4429839162355891e-5 0.00012937879369389463 -4.849246028782613e-5 6.981089978040043e-5 5.297608009100955e-5; -0.00011009113051647159 -0.0001639308448992448 2.339966329795312e-5 6.099418704448193e-5 -0.00011588363678844413 -9.540041912065661e-5 0.0001201610083463824 0.00016454610329868678 -1.6795920708196588e-5 -0.0001264569274261393 0.00012927096566466999 -8.684532633002461e-5 6.154457685819055e-5 2.7417870917030163e-5 -0.00010456441497613211 4.9156927963973466e-5 -8.93194138533227e-5 0.0001930835004551585 -0.00012879937832427544 9.70605850690556e-5 -1.3813825824266694e-5 -6.583241894792635e-5 -0.00014810774711610617 3.724900114515421e-5 0.0002228708925778245 0.00010707805075265794 6.341484721101597e-5 -0.0002172286750627364 -0.0002479564701551371 -3.601158288608158e-5 3.3903803234483874e-5 5.158576266889468e-5; -7.262259772583287e-5 -0.00011800181404292435 -3.2783524728865916e-5 
0.00017161164702075436 3.0093654787511966e-5 1.6961753645908178e-5 7.641729585988592e-5 0.00014773535927919715 -4.6997628159126004e-5 8.78709482650705e-5 -3.3230992484147796e-5 0.00010240223615359601 1.5249757199078863e-5 7.039819012010465e-6 0.0001481547745798881 0.0001148457607643142 5.252271295054442e-6 8.64172703689938e-6 7.340200100896029e-6 -1.5764990747742825e-5 1.0821230363237041e-6 -2.7892505967294127e-5 -3.6800169836034315e-5 0.0001980147584264511 -5.9473887303148495e-5 0.00023849147815885232 -3.915556466280263e-5 0.000139575809372328 4.946546447276752e-5 -0.0002867591016537078 -7.172229255854654e-5 -5.9443473800338515e-5; -6.892226772204088e-5 6.58869996816448e-6 7.21271910396536e-5 0.00011449220367186529 4.5510515726096034e-5 0.00019023481329616114 2.0196740685536675e-5 -3.0541511271691106e-5 0.00019477805219769438 4.968141841066038e-5 -7.975759469725383e-5 -1.6204764299869568e-5 -0.00014199104295864747 -0.00011493220770087247 5.4832756383727203e-5 0.00016608003100093303 -0.00014224818985264834 -0.00011053930463556548 -3.081105639746741e-5 -9.169998045135487e-5 -4.9726192422060475e-5 5.802836064436466e-6 -3.748000633374811e-5 0.00016072480978157068 8.080146401600893e-5 0.0002139856053208172 0.00011721582016154116 -6.950530475758811e-5 5.8748756807222685e-5 -0.00019908992069128935 -8.953440080684335e-5 -7.581179195138245e-5; -4.840570015642123e-5 -8.262426688144796e-5 -4.493418070761152e-5 7.507939319785849e-5 8.465410089573005e-6 -5.371408239864086e-5 0.00012294532810063223 -4.29825505967696e-5 -4.9435430057768806e-5 -8.16451267133963e-5 -0.0001385693669113778 -8.237997660455118e-5 -0.00012578911825823866 0.00015829533533849815 -2.954224343975221e-5 -7.53189143985039e-5 2.6364192151393243e-5 -0.00012700813219690974 7.669092322192752e-5 0.0001401370626148889 8.883107677993423e-5 -2.080197570887721e-5 -0.00011748826925452197 3.415703561348425e-5 -4.3890601020961706e-5 -0.00012422525303247325 -6.41348580507761e-5 0.000130960610319927 3.2443125589794126e-5 
2.133049368001221e-5 -0.000210011951069285 2.197876149456826e-5; -1.5825969763606327e-5 -6.489074237462186e-5 -6.731445116738949e-5 -0.00012846242114018523 6.218224937045715e-5 -0.00023130965362353585 2.6734285020997855e-6 6.519858491091076e-5 7.355798265128435e-6 0.00014972495612172048 -0.00021586086020389575 6.942125967187795e-5 -1.248763853792899e-5 4.5413455358435393e-5 -5.484861119466335e-5 -3.129519223797667e-6 0.00017250928269648688 -0.00015400075587959628 -6.861483214007399e-5 6.335427517490819e-5 -7.945477934836098e-5 -0.00011940858328299968 0.0001270640734438374 -1.5565428635373246e-5 -9.199026478737976e-5 -1.3629947512448787e-5 -5.310533540604723e-5 7.230515823184894e-5 4.1725752657005634e-5 0.00011630072024134567 -0.0002695426572026026 0.00021775980187530478; -5.544734653609736e-5 -0.00011675311383965672 -0.0002526190203259258 -4.89817232352897e-5 3.6658298771943025e-5 -2.8136239275922484e-5 2.902900064713548e-5 9.737836240150879e-5 0.00013819967148650462 -9.836622325180093e-5 -0.00020601689015689449 -0.00015141340395144338 3.217706921917442e-5 -8.790287233365105e-5 0.00010430252746507836 -0.0002818398481811153 -0.00016177059864800158 2.1756329754243384e-6 -0.0001338788554182875 -0.000123042669879992 2.848446434134037e-5 9.506724177684598e-5 0.00022340785813342742 3.176885889320102e-5 0.00013302195452909857 -4.346664922719231e-5 -0.0001163051913369711 -9.969941151732537e-5 0.0002527080812380834 -0.0002005813423654809 1.3216360234099552e-5 -2.2145661428782613e-5; 2.0380423880177203e-5 -5.82618258383668e-5 0.00011736042964480156 -4.092020659733569e-5 -0.00015870671174047176 0.00015799364794953234 -3.8175893646207e-5 2.09369236793052e-5 5.024366058501863e-5 -0.00013056613076036338 9.762942713909982e-6 3.729553739981844e-5 9.614454487314886e-5 5.0932348163092635e-5 3.766606917928523e-5 -3.9981025988476164e-5 0.0001476469888396037 -4.176735696772706e-6 -1.542408849443977e-6 0.0001239625972040907 7.816841847237934e-5 -3.99704394701448e-5 
-0.00013572132135299486 0.00010012157602909828 1.7663630419755566e-5 -3.532189927204311e-5 5.777618665455116e-5 -0.00013851811214841235 2.1447823233082926e-5 -1.8021887305996742e-6 8.281174544685559e-5 4.2585875267146934e-5; 0.0001189113287230196 -9.246661956671705e-5 5.730959496992203e-6 0.00022869459335675706 0.00010369823995029445 -8.669562102545179e-5 0.00012529985961487434 -0.00012441131129412055 6.6285324105824e-5 0.00015123275498377028 -0.00010571911424880919 -1.930820271331426e-5 -4.559099342690907e-5 -7.015985774591652e-5 2.1628308427586822e-5 8.640144738141952e-6 4.1252271880925445e-5 -0.00015185471001563617 3.2991924495268354e-5 -0.000133904558783559 8.18495869728004e-5 -9.524073847819253e-5 -8.360321534814446e-5 -0.00011921310524007341 8.596734242496878e-5 0.00017634288506565525 -0.0002042114387174861 5.444992039452441e-5 -0.00015572931751625747 1.3416197002929274e-5 -6.907651131290856e-5 0.00016909379028698482; 4.962326064211289e-6 0.00013464171225533586 -0.0001082149463323733 6.822577434684853e-5 5.1595329030910566e-5 1.950544363732061e-6 0.00010666124599848633 -1.8373735466947382e-5 6.674710330882274e-5 5.2905849847017155e-6 0.00011460817787193394 4.154509243418169e-5 8.27628902774598e-5 0.00011978962739559639 6.398382557416867e-5 -2.290104729726151e-5 -2.5299324799184492e-5 0.0001432244464089366 -4.760393475552796e-5 8.132418959894585e-6 -0.00010679027017577693 -5.466401834526831e-5 5.0133520024625747e-5 6.407257042920522e-5 -9.278541237252549e-5 -0.000215012579966776 -4.035935107059713e-5 5.1281684326014914e-5 -9.569520254065715e-6 -0.00010937126334427684 0.00017354887696310124 0.00010336088617659669; 0.00013515212659614113 0.0001293074080523259 -0.0001267934942815374 -6.673042052071885e-5 0.00012424016923533613 0.00026082530413178677 1.0779801515202788e-6 -0.00021562763985356686 -6.230937405133652e-5 0.00014169241213595472 -5.485985574422754e-5 2.2609415019458432e-5 -1.4976278442999347e-5 -0.00015079018456971543 0.00011259246704057034 
4.2756308822489515e-5 -5.103103948285004e-5 -0.00012996412377915077 -1.0539085714203994e-5 7.088113371634427e-5 -0.00013234343467857343 3.7822099976539523e-5 -5.06612207472139e-5 3.0911948407361146e-5 4.9179640619615106e-5 -3.4281013439715525e-6 -0.000139692239180395 2.974427580824651e-6 0.0002620543879958054 8.067428162677188e-5 6.702503625192316e-5 6.559689673949816e-5; -3.814877938594936e-5 1.359226778703462e-5 9.76765527227117e-5 0.00016718806862600814 -2.928934254952621e-6 -0.00012407529586226955 4.994492835948199e-5 9.826474841219382e-5 -0.00011069069889373428 -0.00014963669538730118 3.832866222810604e-6 8.561855200213368e-5 5.92994997347771e-5 -0.0001416838263324377 -0.00010003181389065246 6.151174009545622e-5 0.00013858820009744527 -3.5074981681708584e-6 -4.0347129943536154e-5 -5.31432219943542e-5 -0.00012121261399470205 -0.0001336242771870394 4.7798469931591984e-5 -5.36343818751518e-5 3.2567190643811024e-6 5.920599276550033e-5 -9.091434778410448e-5 -1.3529038842805863e-5 1.0022455607838276e-5 1.7723589095904265e-6 -0.00013730904040157917 5.1066979249023526e-5; -3.802937864259141e-5 -7.1573872579631325e-6 -5.9904698004746356e-5 -0.00011579350395061872 -5.96515346975373e-5 0.00014234470289500173 -0.00012650796999389628 0.00013457639578098095 -2.643072004013947e-5 5.1869494186916684e-5 4.135233028625737e-5 5.009016602832493e-5 2.2383814198632966e-5 -2.1015473989451622e-5 -3.300009849647727e-5 -7.047272084912594e-5 -8.836805501653158e-5 -0.00010340857262656502 -2.2513115267999755e-5 9.545366627695617e-5 0.00016326105831955868 1.4866933256939384e-5 -2.7460931973678652e-5 0.00013655621295163023 -2.8483755172260674e-5 0.00011732708661418098 2.7698020844917722e-5 0.00011522873681803527 0.00011211421222958467 -4.8699123012398074e-5 -0.00011030319002623309 -0.00012993947167581465; -1.3281517048934172e-5 2.6673210932302417e-5 -5.762400035548255e-5 -2.3249071482347003e-5 -5.532350262893438e-5 -0.00010743244495625566 0.0001268188439736954 8.199934524799354e-5 
0.00011273377601373062 4.943046003080965e-5 -0.0001079445777848488 -3.6632036817365635e-5 -4.972571189665199e-5 0.00010339802390279954 -5.161566371482809e-5 -1.1318214483081656e-5 0.0001076385611397115 8.212304380338887e-5 6.619441553101233e-5 3.083725966474876e-5 -4.259080405815468e-6 0.00014391347355757918 0.00012522124740843742 8.927947926184519e-5 -6.473772958070877e-5 -0.0001473227225049514 0.00027566754722117067 -0.00012812247444328738 7.945051908808834e-5 -0.00011799716451588025 -4.9814369440195376e-5 2.3854941510015586e-5; 6.758590845224296e-6 -3.111018648114275e-5 0.00016292140591913428 -0.00023927838244917562 2.8827817206278532e-5 -5.062250400287252e-5 9.843390236865508e-5 5.136587080569703e-5 7.399119587872026e-5 -5.885194039633818e-6 5.459789119568424e-5 0.00014391788519942905 2.7724852972555307e-5 8.632269257005093e-5 -3.579709280375022e-5 -0.00011051433446327433 -0.00010409363656280738 0.00010507952194684809 1.4764234239649302e-5 7.678669882961304e-5 -0.00016620215663197037 9.494706702908355e-5 -1.6402651252333582e-5 7.072093669349043e-5 6.0247304843002725e-6 -0.0002310707530144064 0.00017417242290475695 0.00017030568799027642 -2.6233810584285824e-5 -8.866429177787305e-5 -7.55700782777229e-5 -0.00016278065211581475; 5.726927713806661e-6 -0.0002054762234249566 3.360201945975352e-5 6.226148899508632e-5 0.00014226265840330804 -5.1442909250214955e-5 7.711022691433834e-5 -0.00010480805293558627 -1.7488237869871955e-5 7.362169627664415e-5 0.00025641731109568404 -0.0001453513185763217 1.3359496836313243e-5 -2.298965126759757e-5 0.0001150289382283056 9.355993744317065e-5 9.21232486805478e-5 -1.2637559745819199e-5 -0.00010474700037524764 -6.63171748667733e-5 0.00011348200415591302 5.680953028128297e-5 -0.00012675560072897184 4.195278880866558e-5 -2.1362739869106465e-5 6.793400745005897e-5 -0.00010778803051585188 -4.4273748141904264e-5 0.00029736139468947235 5.8530290619058404e-5 -8.51874218708394e-5 -0.00016620512986923643; -1.5458005472169914e-5 
6.495250749241701e-5 -0.00011497840274286453 0.0001140862198018033 1.6535435171247416e-5 1.2414148151150808e-5 5.897078994268687e-5 5.756804352233234e-5 -3.8451928738937205e-5 -2.6927966936341784e-5 0.0001034016288405106 0.00018925468361077912 -0.00013935740822387357 8.140638714496237e-5 6.927452817862939e-5 0.00014421601121386712 3.2678124090333895e-6 0.00014279463834201563 3.3920359093022265e-5 -9.716403543183193e-5 6.619095551386802e-5 1.1927754022621952e-5 7.77167428270769e-6 6.724063260768342e-6 -0.00021969595651069091 7.153507362192691e-5 2.01961768117092e-5 -0.0001725257050192332 4.246272623379325e-5 -3.401251047660069e-5 0.00018951747664787503 -8.511869487746622e-5; -0.00010744268941251028 -0.0001131277370664684 -1.798112528243167e-5 -0.000168997021617975 -0.0001277508839198854 -6.62967270987617e-5 -3.020460674503546e-5 -9.746423013763416e-5 -0.00018876943643461458 8.824362309625791e-5 -0.00011982236749514947 3.823607919133234e-5 -0.00018174589634198187 -4.128708795696928e-6 -5.296758020313143e-5 0.00013842035120563184 -0.0001372415557203353 -7.266073801342884e-5 -0.00012161811021410576 -0.00011603097512159445 -0.00011090457549335892 2.0073955667082842e-5 0.00013121310773514546 -6.137531302425039e-5 -0.00010138744831109124 2.940960955259476e-5 -0.0001342476155773377 -0.00017916982899669732 -0.00023854538297071683 8.97563092361619e-5 6.630189606901703e-5 9.966982331340626e-5; -7.884819939651925e-5 -0.0001387085358743881 -2.1967046592873374e-6 -0.000158504015476573 2.6932054282281226e-6 -3.2944735781414686e-5 -8.191375914679308e-6 -9.872571542080032e-5 -1.1323252752548305e-5 7.13809058054963e-5 9.606321685562374e-5 -3.068455049803948e-5 -8.823525428458514e-5 0.0001342322065596144 -8.09704130522525e-6 -2.0147552336924535e-5 1.0498654714252815e-5 9.32656111529695e-5 1.5282504986612375e-5 -4.759543631447819e-6 -5.323498050136302e-5 -0.00012016478291632064 -2.2442622199081048e-5 5.157549815915003e-5 -0.0001068448328311873 0.00012923232855821507 
8.022665855089276e-5 0.0002107253675915877 5.608710520688497e-6 1.7739926928994973e-5 6.461347942088038e-5 9.768936427856155e-5; -6.579027910094165e-5 -4.147601763280998e-6 6.92643799392448e-5 7.366323765293889e-6 3.0014190209737532e-5 1.716575673880128e-5 3.11807663218936e-5 -9.936466873215231e-6 -0.00022721735228161728 8.785071180126223e-5 3.860236677964044e-5 4.692395953851574e-5 8.922618065652044e-5 -0.0001588102539845859 2.7337398145299816e-5 -8.98848781263227e-6 3.1724789672704834e-5 -8.965407402330027e-5 -0.00010638633909012166 4.8199605893313396e-5 -0.0001371708429961492 5.089654799927255e-6 -0.0002511709106930667 -4.659062402980416e-5 8.115410231210795e-5 -6.673051472363297e-5 0.00011936133444933307 -7.316943162195609e-5 0.0001835531940140627 -1.8727871853219684e-5 0.00012761104802293274 -7.482343141064629e-5; 0.0001572291776318522 1.2289087891875815e-5 -0.00016296867667563145 8.493836961121959e-5 -7.18157998272733e-5 1.8298990682414745e-5 9.112607614841818e-6 -0.00016317135575092976 -2.4868274493310555e-5 1.189410798531732e-5 -1.297069246103503e-6 5.112401507347255e-5 2.119281001057942e-5 1.8122477769683296e-5 -6.412098716111947e-5 0.0001338641277760679 6.286127339251778e-5 4.36466444006122e-5 -0.00013154030872840168 5.605425514535697e-6 0.0001721365670454126 0.00017970507636335227 -0.00016307745224195983 7.885582911683533e-5 4.650206850044913e-5 9.95162600082757e-5 -8.046089621359686e-5 -7.376835398474764e-5 -7.654647467290505e-5 1.0077953476756886e-5 -7.691451443690767e-5 1.448505284486219e-5; -7.814063171029519e-5 -2.9948624767013323e-5 2.4537996826228666e-5 1.0556764057373282e-5 8.354440327163363e-5 0.00015973486991494466 1.18306350784847e-5 2.5103427863331113e-5 9.161290184489953e-5 4.7755950465915146e-5 -0.00013259737640258337 -0.00018799091518961305 -3.651943658699711e-5 6.50607569994177e-5 -9.95111233053018e-5 0.00019554075955308556 0.00011505071855278263 9.67557464970811e-6 -9.753893861418988e-5 -9.761872676538514e-5 -9.343812886730188e-5 
-5.1292559116729196e-5 -0.0001048747409797503 0.00018821797387569742 -0.00011029707563214434 -0.0001033443087792252 0.00010081196054013123 3.162243486510241e-5 3.62961935322968e-5 -6.122470692137686e-5 5.574911741202334e-5 9.516892380605903e-6; 2.666845838793996e-5 0.0001184895373557042 -0.00010171430759111832 0.00013458955808560768 7.413756604727183e-6 0.0002462237317489785 3.2470653353770524e-6 0.00017601266162491284 2.933165896356579e-5 4.889042619870532e-5 0.00014499246034792545 5.8254961997949746e-5 9.206998956178764e-5 6.0412763643759e-5 -6.988187661375471e-5 0.0001619500836499322 -1.9334385713228726e-5 -7.480716149767557e-5 -6.432129710092949e-6 -0.00022975390302933956 -5.029384911506682e-5 -3.2918505810447447e-5 -6.847704199353646e-5 -3.9282425775681626e-5 -6.699976793934056e-5 0.00015718619405779123 -0.0001834173162093404 -2.2819205612567106e-5 3.1217990384912524e-6 -2.5693163395407403e-5 0.00011796617044854111 0.00013356222197433512; -0.000156949276294281 -5.772499776556599e-5 -6.252464141340484e-5 -4.062157154888523e-5 -9.371104011423787e-5 0.00015459861401123794 8.985216640523198e-5 -3.1670710319832975e-5 -7.97268751344659e-5 -4.950557400351564e-5 0.0001864823986974604 -0.00010456759833376276 -3.355897411583118e-5 -0.0001447147390141112 -2.5347339266361406e-5 6.952951457007242e-5 9.471805750872944e-7 3.0440168978589886e-5 2.0002214790497338e-5 -0.0001261557065449628 -0.0001074825215769598 5.973277233661869e-5 -3.696248155082792e-5 -9.308013455354422e-5 -9.73088774631755e-5 9.516572549129299e-5 -2.7778227610305482e-5 -0.0001705391935106949 -9.098694341136339e-5 6.313436255678364e-5 1.292227346714092e-6 -0.00011881181331460841; 0.00017649229714008583 -0.00012320112280720055 6.92150469316698e-5 -1.087393146149689e-5 -0.00024971838418566924 -0.0001455411794100463 -0.00011446192646105914 5.66340831195928e-5 -4.9191497129485415e-5 -0.00011933334743078289 -5.046874419492767e-5 0.00011449835157547973 0.00010287906927116747 6.554340954464682e-5 
-0.0001241637028957128 -8.421575610698386e-5 -1.2995657999979695e-5 5.3720644171724386e-5 -4.146821000319727e-5 -1.2194213997221364e-5 -3.930748706039843e-5 7.923969398545594e-5 -7.524598466405154e-5 -1.477069516256328e-5 2.183495500359701e-5 -0.0001541095217194892 4.061324869728844e-6 0.0001858177610469629 1.766751021788956e-5 -3.3116851301728295e-5 7.94484412094197e-5 8.324575827311977e-6; -9.930144766017807e-5 -3.660254911424946e-5 -4.330311055464956e-6 -0.00016860947685689762 -1.0450085353438215e-5 3.22597476430423e-5 9.67692900691332e-5 -9.350111245781622e-5 -7.286969080938015e-5 -2.294401234153638e-5 0.00017191708288250078 3.6595050589894556e-5 7.266534488975996e-5 4.569097258594515e-5 -0.0001184848227761719 9.991972150841431e-5 -1.4906687608669787e-5 6.190020907764422e-5 -0.00013153821087813787 8.52184890778373e-5 4.530381888127786e-5 5.9867855855050006e-5 -0.00010893516478732182 9.219913012442402e-5 4.403047900509146e-5 0.00023545876287484245 5.726905841821974e-5 7.654043987951985e-5 -1.8885006609121002e-5 5.9027155332518924e-5 -1.0453755816105676e-6 7.440913727134944e-5; 0.00012220782003276504 -0.0001456500402753742 0.00013719691845034293 -7.085265030473785e-5 6.378861792762336e-5 -4.757541199813272e-6 -0.0002511561987092288 -2.0882039160657082e-5 -0.00011427037031245903 -1.2587844020179712e-5 1.3240455694688368e-7 0.00010995396508513623 8.840422316263298e-6 -6.034609111231186e-5 5.674267152352613e-5 -1.0813437137418856e-5 -0.00019796734783887368 -5.181317180250937e-5 -0.00011202902612132203 5.501200860745019e-5 1.693845763947149e-5 0.00011361176360076727 -0.00014973214353523716 -4.165635334067259e-5 3.7615517732030334e-5 0.0001446781599326488 -9.708521128056516e-6 8.452129907817878e-5 -8.556374165784131e-5 5.1907495736651825e-5 0.00010573764951112836 2.217424947156151e-5; 0.00019297272045331625 1.99028399783744e-5 -8.387365416457719e-5 -0.00010467550056821009 0.00014412620474285031 9.715055040932191e-6 -1.540726226995714e-5 -0.00026089067434260195 
-0.00013323076747657174 0.00031607035451269335 2.592181225007835e-5 4.7927192992563076e-5 4.634234754302159e-5 0.00015741804052776194 6.91690638468407e-5 -0.00013063735424840352 -7.753076535483303e-5 -3.6327639480090334e-5 -2.3250043050335637e-6 -0.00010651276717729021 -1.4967445184099577e-5 0.00022476031697987213 -6.339363064217748e-5 -8.615965287402497e-6 -7.779818590098424e-5 0.0001688927350023622 4.738844834899331e-5 -4.706260189597528e-5 5.6662049229988874e-5 3.650094911971689e-5 -6.376071725573316e-5 -0.00017682703948413794], bias = [1.0465272710886513e-9, -8.648627459016243e-10, 6.177044110112219e-10, 7.824372255188783e-10, 1.7317270881859367e-9, -1.2934086645824524e-9, 2.7733778086882753e-9, 8.560409022594865e-10, -2.5813029648001937e-9, -1.8678983153337556e-9, -4.2039080343466e-9, 1.8453968179115852e-9, -2.941258078854813e-10, 2.5467747192429376e-9, 1.7232297175041265e-9, -1.490225543629819e-9, 4.2926819752412854e-10, 2.0023164255827085e-9, 4.331211172852868e-10, 1.3306933065201586e-9, 2.4276606609381394e-9, -7.638235350095781e-9, 4.073938968200113e-10, -1.2449764782599193e-9, 5.939290348373738e-10, -9.201106115737473e-11, 3.026882406138915e-9, -4.091140537189805e-9, -1.7272046346677814e-9, 2.3279814241980605e-9, -8.811810570019317e-10, 6.602767725712663e-10]), layer_4 = (weight = [-0.0005663501057280828 -0.0007457706787273339 -0.0005188263150225607 -0.0007433366158507457 -0.000741371093682523 -0.0009341307863592964 -0.0006739418480388636 -0.0006510175756278891 -0.0006924233644782974 -0.0006901950009094447 -0.0006948877342569193 -0.0007390490542641263 -0.0007627325979109017 -0.0006246767224872426 -0.0008936695369936758 -0.0007963681804177119 -0.0009094797731055124 -0.0008643176313060718 -0.0006595119649590733 -0.0006467480697782629 -0.0005392485277672071 -0.0006863854780232547 -0.0006155302495678297 -0.0006294016237855242 -0.0006313370964353866 -0.0006397267698885347 -0.0007064080736981126 -0.0006159757419482755 -0.0007566116275041469 
-0.0005476718971992025 -0.0005593494619243552 -0.0007092505880843767; 0.0002467457403708252 0.00036611309451303807 0.00026042923785267095 8.46224480049332e-5 0.00018900698291573947 0.00020850816730988469 0.00023331246739341193 7.119816009834347e-5 0.00035330663084348255 9.662076151042844e-5 0.00019913373072917708 0.00020737141920577058 0.00024241835948735202 0.00021394856053768115 0.00014686139607512322 0.0002630157962662496 0.0003537878877022067 0.00032128910705667737 0.00018387553646330725 0.0002595327623632761 5.521782899877448e-5 0.0003574204470791378 0.0002247952257293762 0.00030999700490125356 0.00014623688125948553 0.0003809631976549813 0.00018172789235517209 0.0003616560386596621 0.00016734620793762572 0.00014477961630969327 0.000357276022108114 0.0002972663048220597], bias = [-0.000705533055481162, 0.00020956819356115217]))Visualizing the Results
Let us now plot the loss over time
begin
    # Training curve: one point per recorded optimizer iteration.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; marker=:circle, strokewidth=2, markersize=12)
    fig
end
# Finally, let us visualize the results
# Re-solve the neural ODE with the optimized parameters `res.u` and compute the
# corresponding waveform (first returned component only).
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    # Overlay the data, the untrained prediction, and the trained prediction.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # Untrained neural ODE.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # Trained neural ODE.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # Each legend entry groups a line with its markers.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    fig
end
# Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report GPU toolchain versions only when MLDataDevices and the matching backend package
# are loaded and that device is functional; `&&` short-circuits before any undefined name.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
# Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.