Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.
We need a very crude 2-body path. Assume the 1-body motion is the Newtonian relative position vector of the 2-body system, and split it between the two bodies according to their masses.
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) path into the paths of the two bodies about their
centre of mass.

Body 1 sits at `m₂ / (m₁ + m₂)` times `path` and body 2 at `-m₁ / (m₁ + m₂)` times
`path`. Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    return (m₂ / total_mass) .* path, (-m₁ / total_mass) .* path
end
one2two (generic function with 1 method)
Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix (one column per time sample) into Cartesian orbit
coordinates.

The rows of `soln` are either
- `(χ, ϕ)`: `model_params` must then be a length-3 collection `(p, M, e)` that
  supplies constant `p` and `e` (`M` is not used here); or
- `(χ, ϕ, p, e)`: `p` and `e` vary per sample, and `model_params` is ignored.

The radius follows `r = p / (1 + e cos χ)`. The function returns a `2 × N` matrix
with `x = r cos ϕ` in row 1 and `y = r sin ϕ` in row 2.

Throws an `AssertionError` in two cases:
- `soln` does not have 2 or 4 rows;
- `soln` has 2 rows and `model_params` is missing or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Guard against the default `nothing` explicitly: `length(nothing)` would
        # otherwise raise an unrelated MethodError instead of this message.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)
This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Estimate the first time derivative of the samples `v` (spacing `dt`) to second order.

- Interior points use the central difference `(v[i+1] - v[i-1]) / 2`.
- The first and last points use three-point one-sided stencils.
- At least three samples are needed.
"""
function d_dt(v::AbstractVector, dt)
    first_point = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    last_point = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(first_point, interior, last_point) / dt
end
d_dt (generic function with 1 method)
This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Estimate the second time derivative of the samples `v` (spacing `dt`) to second order.

- Interior points use the central stencil `v[i-1] - 2v[i] + v[i+1]`.
- The first and last points use four-point one-sided stencils.
- At least four samples are needed.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    body = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    return vcat(head, body, tail) / (dt^2)
end
d2_dt2 (generic function with 1 method)
Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along `orbit`, scaled by `mass`.

`orbit` holds `x` in row 1 and `y` in row 2. The result depends on `component`:

- `(1, 1)`: `x² - (x² + y²)/3`
- `(2, 2)`: `y² - (x² + y²)/3`
- anything else: the off-diagonal term `x y`
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    xx = xs .^ 2
    yy = ys .^ 2
    xy = xs .* ys
    tr = xx .+ yy
    moment = if component[1] == 1 && component[2] == 1
        xx .- tr ./ 3
    elseif component[1] == 2 && component[2] == 2
        yy .- tr ./ 3
    else
        xy
    end
    return mass .* moment
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the requested trace-free moment-tensor
component of `orbit`. The derivative is taken by finite differences with sample
spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole strain components `(h11, h12, h22)` for a single body moving
along `orbit` with the given `mass`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_strain(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return component_strain((1, 1)), component_strain((1, 2)), component_strain((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain polarizations of a single (unit-mass) body on `orbit`.

Returns `(√(π/5) (h11 - h22), -√(π/5) · 2 h12)`, with constants built in type `T`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    amplitude = √(T(π) / 5)
    plus = amplitude * (h11 - h22)
    cross = -amplitude * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the summed quadrupole strain components `(h11, h12, h22)` of two bodies.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain polarizations from the orbits of two bodies.

`mass1 + mass2` must equal 1 to within `1e-12`; otherwise an `AssertionError` is
thrown. Returns `(√(π/5) (h11 - h22), -√(π/5) · 2 h12)` using the summed components.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    amplitude = √(T(π) / 5)
    return amplitude * (h11 - h22), -amplitude * (T(2) * h12)
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into the (2,2)-mode strain tuple `(h₊, hₓ)`.

- `mass_ratio == 0` treats the orbit as a single test body.
- Otherwise the relative orbit is split between two bodies with masses
  `m₁ = q/(1+q)` and `m₂ = 1/(1+q)`.

`mass_ratio` must lie in `[0, 1]` (checked with `@assert`). `model_params` is
forwarded to `soln2orbit`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio≤1 "mass_ratio must be <= 1"
    @assert mass_ratio≥0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    # The two masses are normalized so that m₁ + m₂ = 1.
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
compute_waveform (generic function with 2 methods)
Simulating the True Model
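`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}},$$

where $p$ is the semi-latus rectum of the orbit (the radius is $r = p / (1 + e\cos\chi)$), $M$ is the mass of the black hole, and $e$ is the eccentricity.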
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE.

- `u = (χ, ϕ)` is the state; `(p, M, e)` are constant parameters.
- Returns `[χ̇, ϕ̇]`.
- Neither `ϕ` nor `t` enters the rates.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    c = cos(χ)
    shared = (p - 2 - 2 * e * c) * (1 + e * c)^2
    root_term = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = shared * sqrt(p - 6 - 2 * e * c) / (M * (p^2) * root_term)
    ϕ̇ = shared / (M * (p^(3 / 2)) * root_term)
    return [χ̇, ϕ̇]
end
# Setup for the "true" (relativistic) model that generates the training waveform.
mass_ratio = 0.0 # test particle (a mass ratio of 0 selects the one-body strain)
u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
datasize = 250 # number of saved time samples
# NOTE(review): tspan (and therefore tsteps and dt_data) is Float32, while u0 and the
# parameters are Float64; confirm this mixed precision is intended.
tspan = (0.0f0, 6.0f4) # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1] # sample spacing used by the finite-difference derivatives
dt = 100.0 # fixed step size for the RK4 integrator
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic orbit with fixed-step RK4, saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns (h₊, hₓ); only the plus polarization is kept as the data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the true waveform as a line plus markers, grouped under a single legend entry;
# the trailing `fig` makes the block evaluate to the figure so it is displayed.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
axislegend(ax, [[l, s]], ["Waveform Data"])
fig
end
Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the first state variable, χ) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.
Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl.
# Network with a 1-element input and 2 outputs. The first layer applies `cos`
# elementwise to the input. Two `cos`-activated Dense layers (width 32) and a linear
# Dense output layer (width 2) follow. All weights are drawn from a truncated normal
# with std 1e-4 and all biases start at zero, so the untrained network output should
# be close to zero.
const nn = Chain(Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
# NOTE(review): `Random.default_rng()` is not seeded here, so the initial parameters
# (and the printed values) differ between sessions; consider a seeded RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-1.8344539f-5; 1.35544615f-5; 0.00023136898; -2.8087816f-6; -0.00013135397; 6.065132f-5; -2.71867f-6; -2.2196997f-5; -0.00012417452; -8.989728f-5; 5.970759f-5; 8.010392f-5; -0.00015573562; 8.0326945f-6; -5.5042838f-5; 5.0323877f-5; -0.00022916342; -1.18363405f-5; -7.2277864f-5; -0.00014578279; -9.2252645f-5; 9.920721f-5; 4.031525f-5; 0.00014180371; -2.45505f-5; -1.0413913f-5; -3.7547867f-5; -8.592635f-5; 0.000115984716; 3.421788f-5; -6.621697f-5; -9.193076f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[8.390556f-5 1.7640312f-5 8.004551f-5 -1.3293769f-5 4.2423413f-5 9.002561f-5 -8.726852f-5 -8.045185f-5 -0.000101575984 -9.6602074f-5 4.388522f-5 -1.4435316f-5 -7.123656f-5 -1.8593948f-5 -7.816393f-5 5.7547983f-5 0.000121234996 -0.00011111917 -5.265757f-5 -2.9076342f-5 3.6150457f-5 -6.7946865f-5 -7.3112744f-5 -0.0002037348 1.1003254f-5 -3.5596216f-7 9.326856f-5 -1.6159222f-5 -0.00012791745 -2.1953965f-5 -2.0230207f-5 -0.00012478871; -2.0384849f-5 0.000185472 6.9725604f-5 5.239534f-5 0.0002232958 -4.77147f-7 -5.8521335f-5 -6.989044f-5 7.4953314f-5 -7.325985f-5 -7.490285f-5 -7.889888f-5 -8.2580904f-5 -2.5151647f-5 9.575476f-5 8.885926f-5 -5.3387565f-5 -0.00027686064 1.883073f-5 -2.6812968f-5 0.00023380641 -8.7870896f-5 -2.1557764f-5 2.0252612f-5 -0.00023017132 0.00012266063 -0.0001362666 -0.0001025626 2.0433954f-5 7.811506f-5 -8.914855f-5 -0.000108874345; 1.0070617f-5 -1.6552593f-5 -0.0001407587 2.1846965f-5 -3.3476335f-6 -5.4902033f-5 -0.00018565956 1.6667705f-5 -0.00011339857 0.00014792722 1.5395464f-5 0.000140718 1.870823f-5 -0.00017195175 5.5769047f-5 -7.0043396f-5 -4.4462064f-5 0.000114674476 -4.867829f-5 0.00016543792 -0.00023586424 -4.7164813f-5 -0.00015178781 5.437404f-5 -4.615302f-5 0.00012097548 5.1872797f-5 -4.07166f-5 -0.0002228158 -8.297954f-5 
-0.000103180464 9.419802f-5; 5.332961f-5 1.6070708f-5 -9.439099f-5 1.25593315f-5 -5.0150273f-5 -0.00011052204 8.702327f-5 2.1018668f-5 -4.8398462f-5 -4.745263f-6 -3.525184f-5 -0.00014377722 -5.8309247f-6 -0.00014616884 0.00016812743 -1.3368568f-5 4.0645973f-5 7.6899174f-5 6.486938f-5 2.2886523f-5 -7.2055045f-5 -0.00016280312 -0.00013863985 -4.2156626f-5 3.246711f-5 -5.839895f-5 -8.174351f-5 0.0001419509 3.0870597f-5 -0.00011077372 2.798114f-5 -8.845062f-5; -2.3264396f-5 0.00010289389 -4.954925f-5 -5.010119f-5 -2.5439425f-5 9.541054f-5 3.720457f-6 -2.3228382f-5 1.4457631f-5 9.810211f-5 -6.8660236f-5 -0.00022082058 6.614298f-5 -8.414826f-5 0.000108611384 4.14159f-5 8.202914f-5 7.337767f-5 -0.00011661759 0.00018273418 8.858553f-5 7.55119f-5 -6.775059f-5 0.00012376804 -9.2328635f-5 5.496981f-5 9.912566f-5 -0.00014847619 1.3079639f-5 -8.7992754f-5 -0.00013272803 -3.9993516f-5; -0.00013897746 -3.0155674f-5 -0.00022450555 4.362609f-6 0.00010630536 1.8045494f-5 0.00014599868 -0.00012326564 -0.000116558724 -8.3732135f-5 0.00010262302 1.2512724f-5 -5.199991f-5 0.00012526348 0.000107509106 -5.301138f-5 6.9451286f-5 9.835383f-5 0.00011319431 -2.4068302f-6 -3.9676794f-5 -3.896523f-5 -6.7420275f-5 9.515912f-5 -1.1505986f-5 0.000108727116 -2.8098044f-5 -0.00010252405 -1.4673648f-5 -8.868844f-6 -1.632383f-5 0.00022411066; 7.148544f-5 3.9464445f-5 0.00011938321 7.2405484f-5 0.000109655695 5.3021216f-5 1.1144876f-5 -2.5351346f-5 0.0001602143 0.00010032 0.00016363413 -3.3852437f-5 -9.036062f-5 -2.8673718f-5 8.684923f-5 5.9065453f-5 -0.00020033673 0.00010647708 1.6096152f-5 8.416178f-7 9.49033f-5 -0.00020317179 -8.007511f-5 -3.1025043f-5 -1.0339425f-6 -2.2548975f-5 0.000101116195 -1.6690201f-5 0.0001298293 -8.012048f-5 0.00019879898 -8.705337f-5; -3.4238914f-5 3.1139818f-5 -6.409332f-5 0.00012713329 -0.00017495049 -0.00015048916 -2.3166884f-5 6.0234157f-5 -8.7415225f-5 0.00016615517 -1.6980068f-5 -0.00011115777 -2.9920378f-5 -6.168327f-5 7.1898954f-5 -0.00015619797 -4.485504f-5 
5.9469476f-5 -0.00010282306 8.836791f-6 0.000105761734 7.1593306f-5 6.443806f-7 9.182904f-5 0.00012193658 8.320485f-5 -4.912992f-5 0.00020187673 -5.477537f-5 0.00015819768 0.00014537279 -3.4915865f-5; 3.1687938f-5 -7.945534f-5 -8.5847205f-5 -2.0262725f-5 0.00010463885 0.00015181182 -3.0324112f-5 -3.2791337f-5 -7.5211065f-5 -6.853266f-5 -4.9951428f-5 -0.0001360099 0.00015305265 -0.00014691571 -0.00014609905 -0.000101132166 2.4363f-5 2.9928831f-5 0.00019767619 -0.00014222048 2.4173245f-5 3.787323f-5 0.00012716062 -0.00022765445 0.00013160569 5.3255597f-5 -3.422247f-5 1.9223893f-5 4.754344f-5 -1.7426613f-5 3.2066226f-5 1.3286504f-5; -0.00032691148 -5.8981836f-6 6.888751f-5 -2.2982436f-5 -2.4064126f-5 4.035525f-5 -5.6602414f-5 0.0001683377 4.2306896f-5 4.6980975f-5 0.00015411637 -3.2277738f-5 6.1952276f-5 1.8741419f-5 9.059188f-5 -0.00012710888 7.4715725f-5 -5.5439166f-5 -6.593062f-5 6.868757f-5 0.000100793055 -4.146959f-5 8.685888f-5 -0.00012492925 -5.1159303f-5 7.532446f-5 2.6242813f-5 -4.506473f-5 -8.077487f-5 0.0001732688 1.4488429f-5 -8.529635f-5; -4.807562f-5 -7.094981f-5 -0.00014035428 0.00018090864 0.00016708118 0.00014219405 0.00013122588 0.00017091722 -5.4529857f-5 -8.36831f-5 -5.1622605f-6 -2.7931812f-5 -4.9915918f-5 -9.228366f-5 -0.00026639522 6.886846f-5 2.962587f-5 -1.5139781f-5 3.7402242f-5 5.2810297f-5 6.390893f-5 1.4106199f-5 -7.9037054f-5 -9.371044f-5 1.7391816f-5 -1.329885f-5 -0.00014072451 8.315884f-5 7.2019684f-5 3.6479625f-5 5.009066f-5 0.00011500355; -0.000107685955 0.00017478259 -5.0166433f-5 -4.8064776f-5 4.7194026f-5 -2.4513207f-5 7.249118f-5 6.154421f-5 -0.00011399751 2.1920585f-5 -2.1785318f-5 -2.8016197f-5 -5.6880483f-5 5.5279805f-5 6.942182f-5 0.00011370061 -3.9985457f-6 -7.2414034f-5 3.6856232f-5 0.00010319237 -6.9212176f-5 1.1302556f-6 -1.073039f-7 4.7908732f-5 -9.611858f-6 0.00010344315 -9.269871f-5 0.00011482189 -9.581584f-6 -5.7158446f-5 0.00012697098 3.7471298f-6; 4.5940724f-5 1.950335f-5 -5.0505965f-5 0.00012565007 -3.1561452f-5 
3.9537593f-7 1.9656149f-5 -1.2511188f-5 -8.038248f-5 1.0953372f-6 0.000167424 4.305557f-5 6.478444f-5 -5.4298587f-5 -0.00018281242 -1.0980949f-5 3.17604f-5 0.00017518309 6.214768f-5 -2.9993882f-5 6.3033294f-5 -4.9357684f-5 5.6984896f-5 -3.368544f-5 -0.00027104892 -0.00014253845 -1.5064416f-5 -3.424967f-5 7.4674676f-6 0.00014157195 4.1447467f-5 9.9208504f-5; 1.01346495f-5 -5.8032423f-5 -0.00015759438 -0.0001648138 4.1295665f-5 -0.00023703226 -2.7494596f-5 -3.2249195f-6 -3.4421162f-5 -6.476246f-5 -7.153956f-5 -0.00013795368 -3.450096f-6 4.442392f-5 3.2413394f-5 -0.00013244982 -3.8762788f-5 7.6125994f-5 -5.58425f-5 -0.00012286393 -1.0835129f-5 8.741819f-5 0.00010755454 -1.7017355f-5 1.2068588f-5 -0.00015566632 -8.5760774f-5 0.00010725287 -0.0001829899 -3.2242166f-5 -6.2268446f-5 6.835379f-5; -7.5149226f-5 -9.1479116f-5 -0.00019874288 8.1565595f-5 0.000243474 6.7872716f-5 -6.4403657f-6 4.2025556f-5 -0.00011346086 -6.721642f-5 -0.00016819894 -0.00016923886 -5.5581604f-5 7.5089294f-5 3.0082336f-5 0.00015014602 -0.00021502656 2.5278687f-5 -7.8928424f-5 -0.000112831854 -5.93066f-5 -5.83694f-5 4.2915362f-5 -2.660936f-7 -5.2336203f-5 -0.00022719377 6.925578f-5 -0.00010910163 8.95485f-5 -3.9010152f-5 -7.6268385f-5 -3.0962743f-5; 0.0002194758 -7.208452f-5 -7.9508885f-5 -6.747715f-5 7.265772f-5 -9.821001f-5 3.445883f-5 6.5566834f-5 9.425145f-5 -3.4615765f-5 -2.1011658f-5 -2.2389833f-5 -0.00023921474 9.060526f-5 -6.1412597f-6 -0.00010222498 -7.32285f-5 2.2943135f-5 7.420135f-5 -5.4772547f-5 0.00016485534 -2.7939877f-5 8.524147f-6 -9.171936f-6 8.537404f-5 4.8229267f-5 7.124636f-5 -6.854591f-5 7.9665886f-5 -0.0001920346 -0.00014183119 -0.00012554553; 2.7371718f-5 7.982744f-5 0.00016103967 -0.00012817711 -1.2322609f-5 -3.9567793f-5 6.978262f-5 6.3549756f-5 0.00011608599 -3.3571887f-5 -0.00019970827 3.2809123f-5 0.000108450324 5.73983f-5 9.98838f-5 0.00014637897 -2.4289222f-5 4.928635f-6 4.8947186f-5 -0.00010789328 5.8698457f-5 0.00013357171 -5.2049286f-6 0.00024782907 -6.5923183f-7 
7.0653594f-5 -1.502676f-5 -3.913821f-5 1.148293f-5 0.0001628259 -0.000104307604 4.7328755f-5; 0.00014747275 0.00018514553 9.114203f-6 -7.052031f-5 -5.1440762f-5 -7.433742f-5 2.0634276f-5 -0.00010654112 5.3227643f-5 0.0001565159 -8.719439f-5 0.00013842624 3.1250573f-5 -0.00014332008 4.982092f-5 0.00015028851 5.005304f-5 2.8736975f-5 0.00016509031 0.0001240179 -5.793455f-5 0.00024710462 2.0110623f-5 7.720424f-5 -4.5242403f-5 6.1687424f-5 8.388954f-5 -7.138958f-6 -7.991385f-6 3.5619716f-5 0.00010095493 -3.3698292f-5; 6.419906f-5 -3.7794045f-5 -3.657212f-5 6.83079f-5 6.4258114f-5 -0.0003379122 -0.0001865474 7.871551f-5 -6.9247115f-5 3.2984735f-5 -2.290459f-5 3.2870354f-5 -0.00025094755 -1.0554151f-5 6.071415f-5 -8.706581f-5 -1.6664448f-5 -8.3615145f-5 -5.7935526f-5 0.00014819374 -0.00012218053 3.339322f-5 0.0003044563 -6.573003f-5 5.657412f-5 -0.00014351268 -7.995448f-6 -9.35417f-5 -0.000104262865 -3.7184178f-5 3.1951986f-5 -2.9004454f-5; 2.8435501f-5 -9.253271f-5 -0.000105171 2.3529672f-5 -6.807084f-5 0.00013262832 8.857907f-5 0.00015282712 0.000100617144 5.0875216f-5 3.8450657f-5 -0.00014476558 0.0001371801 -0.00010847062 -0.00019695665 -2.801764f-5 -0.00018330685 1.9167212f-6 0.00018102967 -4.603206f-5 4.8315265f-5 -0.00012419242 0.000112665344 -9.427678f-5 -0.00017318413 -4.1850282f-5 -1.5637941f-5 -2.9661313f-5 -8.6039414f-5 7.814542f-5 0.00014442937 3.7359405f-5; -3.448821f-5 -2.463726f-6 -5.525295f-5 -0.00016759854 5.466311f-5 -0.00011802021 7.3083183f-6 0.00015198525 7.2908086f-5 -2.6045433f-5 1.4138735f-5 1.3171745f-5 0.00014567454 3.7154783f-5 9.6880816f-5 -0.00024666055 -7.314138f-5 2.6544465f-5 -6.797863f-5 0.00019963949 9.71645f-5 -0.00011355499 0.00014140533 4.964353f-5 -8.207372f-5 5.7460393f-6 -2.924445f-5 -8.6133405f-5 0.0002608159 -2.0535112f-5 4.373251f-5 2.7678356f-5; -0.00019443732 4.14756f-5 -5.87447f-5 -4.0285595f-5 -7.829823f-5 -2.7288072f-5 0.00023322071 -6.90993f-5 -0.00014072443 -1.2258822f-7 0.00017251167 1.0822593f-5 6.398883f-5 
-8.841678f-5 -2.0400972f-5 -0.00017211496 1.7847138f-5 -3.5838224f-5 9.202452f-5 -0.00024237043 5.685325f-5 -5.8943166f-5 8.813967f-5 -0.00011776248 5.554642f-5 2.142933f-5 -7.7973755f-6 -4.0227114f-5 -0.0002917797 -3.2273034f-5 1.4484156f-5 0.00024261692; -0.00015848443 1.776195f-6 0.00021651592 8.464783f-6 5.8173304f-5 -0.000106315434 -1.2113604f-5 -0.00011232801 0.00015653716 -4.1291594f-5 -8.109294f-5 4.6783665f-5 -0.00012483791 0.000121154786 7.346889f-5 -8.299134f-5 -3.778984f-5 0.00016197938 2.2776108f-5 9.1625334f-5 -2.9627572f-5 9.204186f-5 4.5373963f-5 -9.844927f-6 2.5937932f-5 3.8213435f-5 -9.441349f-5 -0.00029615033 4.2994376f-5 5.14269f-5 -5.76677f-5 0.00012551984; 0.00020347646 4.6201167f-6 5.4567357f-5 6.3812804f-5 -0.00019927234 -8.971765f-5 3.4233843f-5 -0.00011797361 1.2376716f-5 -0.00017088112 0.0001911142 9.733727f-5 -6.685099f-5 -5.0250634f-5 -3.701064f-5 8.068481f-5 7.438458f-5 0.00013435128 -1.4485758f-5 0.00014041884 3.5863464f-5 -5.6816985f-5 5.1864685f-5 -5.5277407f-5 -0.00015057581 -3.47063f-5 4.7381534f-5 -7.3417104f-5 -9.4763585f-5 3.950096f-5 -9.1069975f-5 0.00012384832; 5.6983125f-5 7.3611714f-6 -6.04856f-6 -2.8207354f-5 -0.0001818197 0.0001707087 -2.8581382f-5 0.000114413895 -0.000102580096 -7.132804f-5 -0.00012338476 -0.0002162207 -9.9687946f-5 -0.00027286663 0.00011434125 5.067189f-5 -3.7140104f-5 2.7771398f-6 -3.7295697f-5 0.000100645055 6.754455f-5 4.8290316f-5 4.8184215f-6 -3.9779152f-5 -2.4486264f-5 0.00015479457 -6.814508f-5 0.00016340087 9.4705305f-5 4.10732f-5 7.561145f-5 5.928624f-5; -8.538268f-5 -5.5964614f-5 -9.714605f-6 0.000114119 -6.74988f-5 -0.0001523353 -2.401359f-5 0.00010491322 -2.5857222f-5 1.7097836f-5 -4.401418f-5 -3.9394727f-5 -6.878867f-5 0.00016544724 -9.163528f-5 -3.566248f-5 7.5450014f-5 -0.00012664833 -2.9016005f-6 7.746636f-5 -4.9865448f-5 3.3721208f-5 0.00012050642 -1.2610107f-6 0.00020947291 0.00013436428 -5.326146f-6 4.0695093f-5 -6.848316f-5 4.2014835f-5 -0.00011159645 9.813985f-5; 7.3154197f-6 
2.1003043f-6 -3.1364492f-5 9.561203f-5 -7.812077f-5 0.0001406038 -1.6691822f-5 -0.00017133333 7.177359f-5 -0.00017341177 9.1386006f-5 1.5045061f-5 -5.5450917f-5 9.841491f-5 0.0001790816 -0.00012635032 9.535311f-6 -6.209564f-5 9.5378615f-5 5.8487036f-5 -6.429305f-5 -0.0001146476 5.3372285f-5 -4.119011f-5 -9.2129754f-5 8.567648f-6 8.682677f-5 2.4308996f-5 8.076052f-5 -3.5730678f-5 9.194403f-5 -2.591099f-5; 0.00012653311 0.00014596594 -9.9645105f-5 -0.00014527168 4.622716f-5 -0.00014647626 -5.4383407f-5 -7.5982156f-5 0.00013147047 5.6762216f-5 0.00015175954 4.9724124f-5 5.3424093f-5 -7.4911426f-5 0.00010480568 -0.00014727311 -0.00016644174 -9.763961f-5 2.2586939f-5 5.698055f-5 1.4636252f-6 -2.3871313f-5 -9.2388094f-5 4.910409f-5 2.136647f-5 3.597631f-5 8.00851f-5 -2.6809455f-6 -7.446275f-5 -6.544439f-5 0.00013548725 -0.00013220978; -0.00016425317 5.361061f-5 -9.3313945f-5 9.952655f-5 9.118005f-5 -8.685999f-5 -4.8030717f-5 2.8719312f-6 -7.57515f-6 0.00014490567 5.0670344f-5 -1.9384695f-5 5.3309017f-5 0.00015419668 9.044762f-5 -8.936926f-5 -8.1599144f-5 3.6741927f-5 -1.642957f-5 0.00018356691 0.000104943254 2.9198793f-5 -4.6904075f-5 6.4295746f-5 4.8410522f-5 7.269556f-5 0.00015710558 -7.6455006f-5 -0.00014097808 -0.0001063957 -0.000116628406 -5.5509172f-5; -0.00012230688 -8.8608744f-5 -5.513048f-5 -1.0766721f-5 9.2510025f-5 7.818685f-6 -0.00020047394 4.1064122f-5 6.8537214f-5 6.872923f-5 -6.1089755f-5 -7.83934f-5 1.2051972f-5 -4.3849854f-5 -8.701116f-6 8.504938f-6 2.8712111f-5 9.501201f-5 0.00017670191 6.138665f-5 7.926893f-5 -6.331493f-5 -8.021726f-5 -4.1952255f-5 -3.4194098f-5 5.6444296f-5 -0.000104928826 0.00010619675 -4.2157993f-5 -0.00011295386 0.00012560004 7.816558f-6; -8.1279104f-5 7.1737515f-5 0.000116268195 0.00018544364 -0.00020381967 -6.003825f-5 6.671529f-5 7.236614f-5 7.721068f-6 -8.125255f-6 1.178353f-6 7.057577f-5 -6.783463f-5 2.7520415f-5 -2.489534f-5 0.00010546819 7.960136f-5 7.5966935f-5 0.00018582761 -0.00010862834 7.695435f-5 -0.00010250258 
4.808639f-6 4.910236f-6 -4.715603f-5 -3.9323622f-5 -3.4380917f-6 0.0002810716 -9.091076f-5 -9.786974f-5 -4.602711f-5 -0.00018430213; 0.00012291249 0.00025940032 -8.296358f-5 -1.7811224f-5 -9.384394f-6 -6.413932f-5 1.1512182f-5 1.499024f-5 0.000104384366 -2.1349506f-5 0.0001042217 2.2868713f-5 -3.2120806f-5 0.00015243267 -5.791664f-5 0.00018419129 -5.116078f-6 -9.298067f-5 9.630548f-5 -3.8733f-5 -0.0001202682 9.295712f-5 7.6557364f-5 1.9552688f-5 0.00015192427 8.583486f-5 -6.3417705f-5 0.0002208377 7.5589545f-5 3.2636795f-5 2.121653f-5 -0.0001750932], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[5.6254266f-5 -1.757608f-5 -0.00021620197 -4.8526326f-5 -0.00013788677 -6.007791f-5 -8.9442f-7 -0.00011732837 -7.013296f-5 -4.2808417f-5 0.00010725215 -2.931297f-5 1.2455519f-5 6.6006236f-5 -8.4428706f-5 7.740628f-5 -5.5203505f-5 -0.00010357125 0.00017806374 -1.8403069f-5 -1.6897251f-5 -2.4636736f-5 -0.0002490198 -4.7255973f-5 8.750654f-5 4.3462445f-5 6.200407f-5 -9.909902f-5 -0.00019058448 1.2793834f-5 0.00011962756 -6.858032f-5; 8.378566f-5 -7.97114f-5 -1.6365771f-5 -8.359378f-6 5.592469f-5 -0.00012743636 -2.238498f-5 -0.00010608684 7.650973f-5 2.6785323f-5 -8.860056f-5 -0.00010964528 7.876145f-6 1.3135381f-5 -0.00016829892 -2.8552917f-5 -4.944299f-5 0.00016686285 -0.00020369819 -0.00013099945 2.0491823f-5 -2.2416045f-5 3.2533942f-5 8.056842f-5 4.030629f-5 0.000120034274 0.00030687475 3.4168344f-5 -0.00012378757 -3.7005404f-5 8.880381f-5 0.00013326439], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Like most deep learning frameworks, Lux defaults to `Float32`; in this case, however, we need `Float64`.
# Convert the Float32 parameters to Float64 and flatten them into a ComponentArray, so
# the ODE solver and optimizer can treat them as a single parameter vector.
const params = ComponentArray(ps |> f64)
# Bundle the network with its state `st`, so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
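Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics, with each rate corrected by the neural network:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $\mathcal{N}_1$ and $\mathcal{N}_2$ are the two outputs of the neural network. $p$, $M$ and $e$ are the fixed parameters stored in `ode_model_params` (semi-latus rectum, mass and eccentricity).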
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural-network-corrected Newtonian orbit ODE.

- `u = (χ, ϕ)` is the state.
- Both rates share the Newtonian factor `(1 + e cos χ)^2 / (M p^(3/2))`, where
  `p, M, e` come from the global `ode_model_params`.
- Rate `i` is multiplied by `1 + Nᵢ(χ)`, where `N` is `nn_model` evaluated with
  `nn_params`.
- `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # `nn_model` is a StatefulLuxLayer that already carries the (empty) layer state
    # `st`, so only the trainable parameters are passed here.
    scale = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * scale[1], newtonian_rate * scale[2]]
end
ODE_model (generic function with 1 method)
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Neural ODE problem whose parameter vector is the network's ComponentArray `params`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# Simulate with the untrained parameters, using the same solver settings as the true model.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
# Overlay the true waveform and the untrained neural-ODE waveform for comparison;
# each series is a line plus markers grouped under one legend entry.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
axislegend(ax, [[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
"""
    loss(θ)

Simulate the neural ODE with network parameters `θ`. Return the mean-squared error
between the predicted plus-polarization waveform and the true `waveform`.
"""
function loss(θ)
    # Same fixed-step RK4 settings and sampling grid as the true-model simulation.
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred, _ = compute_waveform(
        dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end
loss (generic function with 1 method)
Warm up the loss function (the first call also compiles it).
# First call compiles the solve + waveform pipeline; the result is the untrained loss.
loss(params)
0.0006795253562107642
Now let us define a callback function to store the loss over time
const losses = Float64[]
"""
    callback(state, current_loss)

Optimization callback. It records `current_loss` in `losses`, prints the iteration
number (`state.iter`) and the loss, and returns `false` so optimization continues.
"""
function callback(state, current_loss)
    push!(losses, current_loss)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter current_loss
    return false
end
callback (generic function with 1 method)
Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.
# Gradients of `loss` are taken with Zygote (reverse mode) through the ODE solve.
adtype = Optimization.AutoZygote()
# Optimization.jl objectives take `(u, p)`; the second (hyperparameter) slot is unused.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations.
res = Optimization.solve(optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-1.83445390575749e-5; 1.3554461474968984e-5; 0.00023136897652836655; -2.808781573548517e-6; -0.00013135396875455375; 6.065132038197924e-5; -2.718669975362902e-6; -2.2196996724232316e-5; -0.0001241745194417971; -8.989727939471412e-5; 5.970758866166764e-5; 8.010392048152653e-5; -0.000155735615407925; 8.032694495335232e-6; -5.5042837630019185e-5; 5.032387707606982e-5; -0.00022916341549707178; -1.183634049084226e-5; -7.227786409195156e-5; -0.00014578278933170992; -9.225264511764554e-5; 9.920720913202298e-5; 4.0315251681038955e-5; 0.00014180371363157218; -2.4550499801984752e-5; -1.0413912605126551e-5; -3.754786666831233e-5; -8.592635276720837e-5; 0.00011598471610319305; 3.4217879146993704e-5; -6.621696957147267e-5; -9.193075675289229e-5;;], bias = [-5.5086991629139336e-18, 3.0767854668807254e-17, 4.482124815557636e-16, -4.05617847660039e-19, 1.637671026088228e-16, 4.757176729141586e-17, -3.1167306523050347e-18, -7.618225192025676e-19, -2.5620667272598245e-16, -8.660427539466901e-17, 3.1320897837310065e-17, 1.2117578267177638e-16, -1.4075125953833216e-16, 6.954072212953764e-18, -7.821186102477161e-18, 9.332353540512634e-17, -1.1389931229904559e-16, -3.1530488530554954e-18, -8.819011598099269e-17, -1.7126897100687712e-16, -6.01437885946302e-17, 1.4081188822041886e-16, -1.9096925755025182e-17, 2.777931959746747e-16, -5.835503415547131e-18, -2.112089215389558e-17, -7.73477866669203e-18, -7.457864165617491e-17, 2.4836584479660024e-16, 4.8369145119432824e-17, -7.892592706048722e-17, 5.3356829206410074e-17]), layer_3 = (weight = [8.390292685313405e-5 1.763768130282184e-5 8.004287854374543e-5 -1.329639992085522e-5 4.242078255581513e-5 9.00229821849855e-5 -8.72711501147389e-5 -8.045448004899276e-5 -0.0001015786149951288 -9.660470488646079e-5 4.388258792411138e-5 -1.443794673881181e-5 -7.123918869230453e-5 -1.85965785208647e-5 -7.816656245316821e-5 5.754535202819368e-5 0.00012123236563594151 
-0.00011112179738169299 -5.2660201909507384e-5 -2.907897243905626e-5 3.614782661267582e-5 -6.794949542430089e-5 -7.311537439867487e-5 -0.00020373743274878625 1.1000623707750818e-5 -3.5859270773481295e-7 9.326592742532841e-5 -1.6161852109265286e-5 -0.00012792007721193615 -2.1956595738210226e-5 -2.0232837705894858e-5 -0.00012479134267826542; -2.03860391226965e-5 0.00018547081248498098 6.972441309728436e-5 5.239414798806921e-5 0.00022329461620784603 -4.783375945406299e-7 -5.852252530503519e-5 -6.989163108273822e-5 7.495212315944187e-5 -7.326103974237224e-5 -7.4904038283148e-5 -7.890006696443255e-5 -8.258209441891365e-5 -2.515283727889379e-5 9.57535693853969e-5 8.885807159485894e-5 -5.338875513159933e-5 -0.000276861826077676 1.8829538876435288e-5 -2.6814158499921495e-5 0.00023380521719473308 -8.787208667826888e-5 -2.1558954412922225e-5 2.0251421065319997e-5 -0.00023017251538449374 0.00012265943491029127 -0.00013626778405496294 -0.00010256378760432037 2.0432763394914523e-5 7.811386871280692e-5 -8.914973755936402e-5 -0.00010887553531547179; 1.006803971591538e-5 -1.6555169407584443e-5 -0.0001407612737815643 2.1844388117241782e-5 -3.350210375560634e-6 -5.4904610136361364e-5 -0.00018566213954131016 1.6665128656822785e-5 -0.00011340114839497476 0.00014792464092949263 1.539288745021853e-5 0.00014071541839854506 1.87056527798466e-5 -0.0001719543227076799 5.576647048343334e-5 -7.004597251892701e-5 -4.4464640507969206e-5 0.00011467189937007731 -4.868086513255176e-5 0.0001654353388735935 -0.00023586681460386247 -4.716738955124275e-5 -0.00015179038861191593 5.437146204401682e-5 -4.615559852340104e-5 0.00012097290049183296 5.187022064688839e-5 -4.071917761706017e-5 -0.00022281837415305582 -8.298211583991333e-5 -0.00010318304090418379 9.419544106691964e-5; 5.332745027414174e-5 1.6068546443480193e-5 -9.439315212502468e-5 1.255717020836985e-5 -5.015243399968459e-5 -0.00011052420481422098 8.702110758901202e-5 2.101650684932374e-5 -4.8400623512791776e-5 -4.747424412861361e-6 
-3.5254001289951895e-5 -0.00014377937654347077 -5.833085964619548e-6 -0.00014617099836317261 0.00016812526851578457 -1.3370729306127956e-5 4.0643811786236014e-5 7.689701274654466e-5 6.48672209136467e-5 2.2884361584221088e-5 -7.205720643826534e-5 -0.00016280527810879665 -0.0001386420120469972 -4.215878684037252e-5 3.24649478498627e-5 -5.8401110749341486e-5 -8.174566787050379e-5 0.0001419487449821379 3.086843540229478e-5 -0.00011077588018809736 2.7978978010713416e-5 -8.845278384244779e-5; -2.3264074579766343e-5 0.00010289421469687922 -4.9548928027349866e-5 -5.0100867620040955e-5 -2.5439103874217026e-5 9.541086318686073e-5 3.720778513615085e-6 -2.3228060408565552e-5 1.4457952138488601e-5 9.810242908324467e-5 -6.86599146410131e-5 -0.0002208202627206464 6.614330185593289e-5 -8.414794002410517e-5 0.00010861170584584994 4.14162218386069e-5 8.202946036371283e-5 7.337799023853121e-5 -0.00011661727195123758 0.00018273449813031802 8.858585487368865e-5 7.551222505762577e-5 -6.775026714413005e-5 0.00012376836229131778 -9.232831372164689e-5 5.497013034082453e-5 9.912598169657825e-5 -0.00014847586727219275 1.3079960344928009e-5 -8.799243233616037e-5 -0.00013272771185247767 -3.9993194613909584e-5; -0.00013897681854365993 -3.015503714784177e-5 -0.00022450491449497477 4.363245825453112e-6 0.00010630599355856873 1.8046130226075505e-5 0.00014599931821060707 -0.00012326499925627047 -0.00011655808701113775 -8.373149828919037e-5 0.00010262365323790613 1.2513361089012803e-5 -5.1999271800081934e-5 0.00012526411957872438 0.00010750974253817738 -5.3010742685791944e-5 6.945192219056376e-5 9.835446327991191e-5 0.00011319495026321382 -2.4061934983129027e-6 -3.9676157011918926e-5 -3.896459382310372e-5 -6.741963805331541e-5 9.515975763854495e-5 -1.1505348975396692e-5 0.00010872775239470755 -2.8097407248487373e-5 -0.00010252341233360394 -1.467301170129965e-5 -8.868207241966483e-6 -1.6323192521731663e-5 0.0002241112925002461; 7.148827780309771e-5 3.946728310819434e-5 0.00011938604502041272 
7.240832264341149e-5 0.00010965853352369746 5.302405469972967e-5 1.11477144802731e-5 -2.5348508011860143e-5 0.00016021714026844303 0.00010032283962043854 0.00016363697131435722 -3.384959864508287e-5 -9.035778139681267e-5 -2.8670879424510432e-5 8.685206620992547e-5 5.9068290915260697e-5 -0.00020033388801117885 0.00010647991502463542 1.6098989994299575e-5 8.444560595904405e-7 9.490613655923749e-5 -0.00020316895034359076 -8.007226943465363e-5 -3.10222051719695e-5 -1.0311041986037178e-6 -2.254613655696897e-5 0.00010111903311024956 -1.668736268204186e-5 0.0001298321446430905 -8.011764230633527e-5 0.00019880181998285725 -8.705053124305905e-5; -3.4237829704371076e-5 3.114090225629256e-5 -6.409223703466632e-5 0.00012713437184195723 -0.0001749494039084103 -0.00015048807096697768 -2.316579945706544e-5 6.023524122631844e-5 -8.741414046178433e-5 0.00016615624929747943 -1.697898357612919e-5 -0.0001111566888264614 -2.991929420184517e-5 -6.168218521431337e-5 7.190003828348237e-5 -0.0001561968891744722 -4.485395754253733e-5 5.9470559908940736e-5 -0.00010282197751600726 8.83787482072998e-6 0.00010576281795713504 7.159438985602335e-5 6.454648134685819e-7 9.183012310151914e-5 0.00012193766463517873 8.320593781739223e-5 -4.912883729936852e-5 0.0002018778105868721 -5.4774285554026965e-5 0.00015819876728549314 0.00014537387335642772 -3.491478116281877e-5; 3.168718551816468e-5 -7.94560938368547e-5 -8.584795711728314e-5 -2.0263477446208297e-5 0.0001046380975173833 0.00015181107021700724 -3.0324864447977914e-5 -3.279208892425469e-5 -7.521181675772693e-5 -6.853341171452536e-5 -4.9952180298789054e-5 -0.00013601064990442048 0.00015305189747661993 -0.0001469164663569997 -0.00014609979832246782 -0.00010113291779933392 2.4362247976324826e-5 2.9928079064772906e-5 0.00019767543283589356 -0.00014222123270090108 2.417249282072549e-5 3.7872477085984345e-5 0.00012715986388567917 -0.0002276552068645081 0.00013160493556708158 5.3254844650580044e-5 -3.4223222493239857e-5 1.9223140887988168e-5 
4.754268677854352e-5 -1.742736555687082e-5 3.20654734684635e-5 1.3285752110022747e-5; -0.00032691089311678 -5.897594499889778e-6 6.88880970951251e-5 -2.298184720397017e-5 -2.4063536537633085e-5 4.035584040271957e-5 -5.660182453406064e-5 0.00016833829410776462 4.2307485065484964e-5 4.698156387481662e-5 0.00015411695473764 -3.227714862854366e-5 6.19528653000619e-5 1.8742008043285233e-5 9.059246783314193e-5 -0.0001271082948992781 7.471631376706895e-5 -5.543857717251428e-5 -6.593002783896522e-5 6.868816105585118e-5 0.000100793644170573 -4.146900048775809e-5 8.685947229556688e-5 -0.000124928665380631 -5.11587133846994e-5 7.532504948491206e-5 2.6243402109196857e-5 -4.506414209439582e-5 -8.077428376831394e-5 0.00017326938520587936 1.4489018004917415e-5 -8.529576028062214e-5; -4.8074698288700755e-5 -7.094889021595865e-5 -0.00014035336223084129 0.00018090956130413423 0.00016708209862349122 0.00014219497664248545 0.0001312268052331553 0.00017091814354876128 -5.452893471579445e-5 -8.36821798386627e-5 -5.161338113112389e-6 -2.793088949170073e-5 -4.991499533412688e-5 -9.228274008842106e-5 -0.0002663942957681232 6.886938192997135e-5 2.962679193818389e-5 -1.513885833469375e-5 3.74031644512333e-5 5.2811219784183494e-5 6.390984911354089e-5 1.4107121596450662e-5 -7.903613165603919e-5 -9.37095189967803e-5 1.739273850921605e-5 -1.3297927371467611e-5 -0.00014072359205806584 8.315976286189493e-5 7.202060644862854e-5 3.6480547559904404e-5 5.0091583243328585e-5 0.00011500447565572168; -0.00010768456667013316 0.00017478397677602402 -5.016504399996638e-5 -4.806338728106415e-5 4.719541431237331e-5 -2.451181825184263e-5 7.249256924675941e-5 6.154559997791568e-5 -0.0001139961179304362 2.1921973542853137e-5 -2.1783928956058756e-5 -2.8014808036309043e-5 -5.6879093959508436e-5 5.528119328498072e-5 6.94232065272148e-5 0.00011370199973877572 -3.99715707849935e-6 -7.241264500938198e-5 3.68576209568039e-5 0.00010319375616214376 -6.921078710168818e-5 1.131644184319141e-6 -1.0591526792672335e-7 
4.7910120714876836e-5 -9.610469702085024e-6 0.00010344453659324135 -9.269732362175259e-5 0.00011482327573882249 -9.580195351952904e-6 -5.715705736823268e-5 0.00012697236417276753 3.7485183967552735e-6; 4.59413218909571e-5 1.9503948000881293e-5 -5.050536709450532e-5 0.0001256506671911734 -3.156085439104227e-5 3.9597403156831504e-7 1.9656746748757983e-5 -1.2510590191134373e-5 -8.038188242198504e-5 1.0959353080631998e-6 0.0001674246028546865 4.305616822830972e-5 6.478504075357912e-5 -5.429798910471847e-5 -0.0001828118241842214 -1.0980350803486591e-5 3.176099709389405e-5 0.00017518368405445238 6.214827736994846e-5 -2.99932838838659e-5 6.3033892378831e-5 -4.935708637776281e-5 5.698549431554707e-5 -3.3684841023838826e-5 -0.00027104832259782836 -0.00014253785628670133 -1.5063818305400501e-5 -3.424907334692225e-5 7.468065736706749e-6 0.00014157254337488414 4.144806518025375e-5 9.920910234951015e-5; 1.013006603411676e-5 -5.80370066066736e-5 -0.0001575989586901986 -0.0001648183821137995 4.129108164732113e-5 -0.00023703684528786147 -2.7499179491633595e-5 -3.229503004467605e-6 -3.442574562116099e-5 -6.476704548307428e-5 -7.154414161536317e-5 -0.0001379582678103725 -3.4546795680496524e-6 4.441933596776802e-5 3.240881087063325e-5 -0.00013245439877695964 -3.876737137512873e-5 7.612141088852673e-5 -5.5847085245970374e-5 -0.00012286851379514303 -1.0839712199719117e-5 8.741360979426249e-5 0.00010754995345873996 -1.702193877037421e-5 1.2064004654030689e-5 -0.0001556709026820957 -8.576535729421098e-5 0.00010724828498010687 -0.00018299448093360158 -3.2246749110844166e-5 -6.227302911187806e-5 6.834920386123791e-5; -7.515308930965982e-5 -9.148297936850156e-5 -0.00019874674722740178 8.156173175006222e-5 0.00024347013863076808 6.786885245741857e-5 -6.44422885603455e-6 4.202169333506485e-5 -0.00011346472416265746 -6.722028006993103e-5 -0.0001682028010320707 -0.0001692427245500319 -5.5585467234073945e-5 7.508543098150808e-5 3.0078472664849462e-5 0.00015014215790938023 -0.0002150304276794369 
2.5274823602322215e-5 -7.893228717807908e-5 -0.000112835717626904 -5.9310462856272244e-5 -5.8373264945881645e-5 4.291149929553522e-5 -2.6995672599867016e-7 -5.2340066358063986e-5 -0.00022719763324790572 6.925191741242298e-5 -0.00010910549422919707 8.954463595797183e-5 -3.901401538247048e-5 -7.627224800603797e-5 -3.0966606172340074e-5; 0.0002194746960121823 -7.20856221341844e-5 -7.950998745580528e-5 -6.747825294223986e-5 7.265661794736293e-5 -9.821111210112703e-5 3.445772769555075e-5 6.556573159104789e-5 9.425035047388814e-5 -3.461686695300127e-5 -2.1012759800215855e-5 -2.2390935311705392e-5 -0.00023921584555136474 9.060415711064451e-5 -6.142361809908508e-6 -0.00010222608010008284 -7.322959912197093e-5 2.2942033199794634e-5 7.420024873499571e-5 -5.47736487824245e-5 0.0001648542422735777 -2.7940979391131773e-5 8.523045169565926e-6 -9.173037808063717e-6 8.537294003690241e-5 4.8228165058387456e-5 7.124525679318841e-5 -6.854701108022705e-5 7.966478394155496e-5 -0.00019203569875753817 -0.00014183228971160233 -0.00012554663019909596; 2.7375398745753238e-5 7.983112230791487e-5 0.0001610433530330648 -0.0001281734300974125 -1.2318928169463193e-5 -3.9564111615658654e-5 6.978629902365416e-5 6.355343724495471e-5 0.00011608967109708859 -3.356820594858618e-5 -0.0001997045917633332 3.281280372534837e-5 0.00010845400469036457 5.7401982518022644e-5 9.988748137034841e-5 0.00014638264768651179 -2.428554104345366e-5 4.932315947335584e-6 4.8950866799672224e-5 -0.00010788959549615117 5.870213793600617e-5 0.0001335753906786955 -5.201247644473812e-6 0.00024783275078874267 -6.555508903512773e-7 7.065727480581746e-5 -1.5023078744017044e-5 -3.913452816418528e-5 1.1486610888771714e-5 0.00016282958607549434 -0.00010430392354832758 4.733243552977865e-5; 0.0001474770753435569 0.00018514985929017554 9.118531279959586e-6 -7.051598443409738e-5 -5.14364335953255e-5 -7.433309004169787e-5 2.063860408328631e-5 -0.00010653679130824659 5.3231971078834626e-5 0.0001565202320949774 -8.719006366788292e-5 
0.00013843057165161416 3.1254901087414126e-5 -0.00014331575296579245 4.982524767059094e-5 0.0001502928418364123 5.005736890830698e-5 2.8741303295317563e-5 0.00016509464255307682 0.00012402222561886926 -5.7930222127992936e-5 0.0002471089457418413 2.011495159480182e-5 7.720856881006246e-5 -4.5238074407609376e-5 6.169175246271914e-5 8.389387146111069e-5 -7.134629713048538e-6 -7.987056805525312e-6 3.5624044513166803e-5 0.00010095925897133538 -3.369396365940218e-5; 6.419598335861614e-5 -3.7797124460382393e-5 -3.657520013860369e-5 6.83048184867518e-5 6.425503503062959e-5 -0.0003379152857093534 -0.0001865504835593013 7.8712428293614e-5 -6.925019428369573e-5 3.298165589451079e-5 -2.2907670282332628e-5 3.2867274202835367e-5 -0.00025095062468668487 -1.0557230635356291e-5 6.071106887185668e-5 -8.706889078956425e-5 -1.666752708900297e-5 -8.361822423529445e-5 -5.793860495094272e-5 0.0001481906566908363 -0.0001221836081100974 3.339013906890921e-5 0.00030445322605311453 -6.573311278831668e-5 5.657104171342629e-5 -0.00014351575540848643 -7.99852736298434e-6 -9.354477682852963e-5 -0.00010426594402876863 -3.718725733113938e-5 3.1948906470749936e-5 -2.9007533012807592e-5; 2.8434851199243403e-5 -9.253336108699986e-5 -0.0001051716485311219 2.352902222596494e-5 -6.807148972470497e-5 0.00013262767381788692 8.857842251713771e-5 0.0001528264743025361 0.00010061649438980743 5.087456659352961e-5 3.845000676633319e-5 -0.00014476623116238454 0.00013717945469366352 -0.00010847126620533491 -0.0001969573008612778 -2.801828893311816e-5 -0.00018330750251365266 1.9160713942696165e-6 0.00018102901733768666 -4.6032708698791515e-5 4.831461549442824e-5 -0.000124193068106569 0.00011266469439546998 -9.427742995498823e-5 -0.0001731847828271852 -4.185093209557781e-5 -1.5638590841582567e-5 -2.9661962318959836e-5 -8.604006368145609e-5 7.814477205797509e-5 0.00014442872409531392 3.7358755005380166e-5; -3.448706253194366e-5 -2.4625785606167557e-6 -5.525180155175982e-5 -0.00016759739288473607 5.4664256584404e-5 
-0.00011801906135743726 7.309465889457047e-6 0.00015198639970393727 7.290923325698144e-5 -2.6044285061028188e-5 1.4139882967618727e-5 1.3172892739282749e-5 0.00014567568517875773 3.715593060250289e-5 9.68819638167486e-5 -0.00024665940351197685 -7.314023447336834e-5 2.654561261650998e-5 -6.797748416075782e-5 0.00019964063334398935 9.716564612816597e-5 -0.00011355384254940327 0.00014140647339289838 4.96446789277632e-5 -8.207257111435882e-5 5.747186822454771e-6 -2.9243301707274904e-5 -8.613225714881354e-5 0.00026081704503511403 -2.053396407219262e-5 4.373365822019721e-5 2.7679503308069427e-5; -0.00019443947699673188 4.1473444201446114e-5 -5.87468547169462e-5 -4.028775024980042e-5 -7.830038690060764e-5 -2.729022695037946e-5 0.0002332185523695525 -6.910145189666823e-5 -0.00014072658255516134 -1.2474359069509312e-7 0.00017250951909173796 1.0820437767882984e-5 6.398667587148128e-5 -8.841893746338517e-5 -2.0403127433918314e-5 -0.00017211711552280487 1.7844982852192974e-5 -3.5840379365791625e-5 9.202236664660904e-5 -0.0002423725835402134 5.6851093307537256e-5 -5.894532101281895e-5 8.813751459766556e-5 -0.0001177646353031155 5.554426583640583e-5 2.1427175073230554e-5 -7.799530846195861e-6 -4.022926974047125e-5 -0.00029178185091639596 -3.227518923867148e-5 1.4482000682084944e-5 0.0002426147676882606; -0.00015848396754281836 1.7766559454658034e-6 0.000216516384962768 8.465243691350264e-6 5.8173764837733105e-5 -0.00010631497309932781 -1.2113142592536822e-5 -0.00011232754798529917 0.00015653762499877712 -4.129113324725432e-5 -8.109248040400742e-5 4.678412630764431e-5 -0.00012483745115733378 0.0001211552470254874 7.346934841789017e-5 -8.299087960895431e-5 -3.7789378556600895e-5 0.00016197983756737546 2.2776569376761634e-5 9.162579487619358e-5 -2.962711107957546e-5 9.204232162173449e-5 4.5374424071812215e-5 -9.84446626662822e-6 2.5938393146677246e-5 3.8213895887199246e-5 -9.441302711766477e-5 -0.0002961498731832801 4.2994836686006965e-5 5.1427362332648816e-5 -5.7667240188860604e-5 
0.00012552030500100558; 0.00020347676727307097 4.6204218952085295e-6 5.456766266343652e-5 6.38131092264417e-5 -0.00019927203575297357 -8.971734532768804e-5 3.4234147823565425e-5 -0.00011797330116768303 1.2377021083931076e-5 -0.00017088081258497206 0.00019111450783294663 9.733757869890896e-5 -6.685068401656286e-5 -5.0250328423305636e-5 -3.701033383639489e-5 8.068511633520059e-5 7.438488646494705e-5 0.00013435159016718227 -1.448545244874181e-5 0.0001404191421888848 3.586376951626488e-5 -5.681668008127278e-5 5.1864990017759237e-5 -5.52771020020278e-5 -0.0001505755066327651 -3.4705994404237014e-5 4.73818396121828e-5 -7.341679920590487e-5 -9.476328014078613e-5 3.950126447217599e-5 -9.106966936177548e-5 0.00012384862893578583; 5.6983087829859874e-5 7.361134662925909e-6 -6.0485966464396346e-6 -2.8207390955220034e-5 -0.00018181974186926508 0.00017070867002577146 -2.8581418832334747e-5 0.00011441385836907285 -0.00010258013239262053 -7.132807545029927e-5 -0.00012338479459184597 -0.00021622073140359856 -9.968798289672821e-5 -0.00027286666622284244 0.00011434121520825287 5.067185307371986e-5 -3.714014050682391e-5 2.777103105028803e-6 -3.7295733222424977e-5 0.00010064501807893273 6.754451138068555e-5 4.829027935148781e-5 4.818384837119105e-6 -3.977918854114954e-5 -2.44863007032909e-5 0.00015479453546096193 -6.814511320452582e-5 0.00016340082916532273 9.470526825272117e-5 4.107316240653019e-5 7.561141651923394e-5 5.928620490113836e-5; -8.538208239989201e-5 -5.596401788571693e-5 -9.714008594911506e-6 0.00011411959676260288 -6.749820058977406e-5 -0.0001523347022179319 -2.4012994256199595e-5 0.00010491381508253665 -2.5856625509198282e-5 1.709843272473201e-5 -4.401358492251394e-5 -3.939413036497274e-5 -6.878807507966025e-5 0.00016544783975185984 -9.163468385159869e-5 -3.5661884251030914e-5 7.545061052731306e-5 -0.0001266477342163942 -2.901004247551894e-6 7.746695300519216e-5 -4.98648518281803e-5 3.372180399090952e-5 0.00012050701443294193 -1.2604143898116099e-6 0.0002094735096739806 
0.0001343648760644921 -5.325549775452652e-6 4.069568931080417e-5 -6.848256489539622e-5 4.201543160669643e-5 -0.00011159585156056955 9.81404442405514e-5; 7.315850550172523e-6 2.1007351706304326e-6 -3.1364061444966994e-5 9.561245837811492e-5 -7.812034005202485e-5 0.00014060423722820416 -1.6691390786048767e-5 -0.00017133290225775807 7.177402016803515e-5 -0.00017341133775791156 9.13864366766444e-5 1.5045492315935182e-5 -5.545048611840464e-5 9.841533915003871e-5 0.00017908203312525596 -0.0001263498909365937 9.535741580273705e-6 -6.209521075380856e-5 9.537904565785118e-5 5.8487466728350125e-5 -6.42926227128693e-5 -0.00011464716795036596 5.337271590814083e-5 -4.118967831620187e-5 -9.2129323323219e-5 8.568079235361595e-6 8.682719794431793e-5 2.4309427087232544e-5 8.076094831902456e-5 -3.5730247193909234e-5 9.194445896990647e-5 -2.5910558409136796e-5; 0.00012653265643415654 0.0001459654877569308 -9.964555734518445e-5 -0.0001452721343932892 4.622670807543143e-5 -0.00014647671283206238 -5.438385865751045e-5 -7.598260795681121e-5 0.00013147001940773704 5.676176355586185e-5 0.0001517590854225467 4.9723671548001064e-5 5.3423641513968256e-5 -7.491187803430804e-5 0.00010480522818647048 -0.0001472735611580527 -0.00016644218790793337 -9.764006332012214e-5 2.2586486959645872e-5 5.698009685397141e-5 1.4631731735263808e-6 -2.387176478556068e-5 -9.238854632463604e-5 4.910363626799522e-5 2.136601782944206e-5 3.597585877536476e-5 8.008464559929143e-5 -2.6813974793713144e-6 -7.44631988320698e-5 -6.544483854418925e-5 0.00013548679912013675 -0.00013221023342089924; -0.00016425217224248437 5.3611608320281604e-5 -9.331294610828716e-5 9.952755020269453e-5 9.118105184090022e-5 -8.685898980554983e-5 -4.8029717862007384e-5 2.872930492159924e-6 -7.574150702642599e-6 0.00014490667193233032 5.067134291668568e-5 -1.9383696115328406e-5 5.331001623918704e-5 0.00015419767690393402 9.044861756767309e-5 -8.936826385986812e-5 -8.159814503241853e-5 3.674292632415217e-5 -1.6428571387295625e-5 
0.00018356790927274652 0.00010494425326389246 2.919979190839395e-5 -4.6903075843218493e-5 6.429674497841462e-5 4.841152143119145e-5 7.269656155760658e-5 0.00015710657565425363 -7.645400661938802e-5 -0.0001409770823266468 -0.00010639470316898787 -0.00011662740623076328 -5.5508172652553906e-5; -0.0001223072869977014 -8.860914769294055e-5 -5.513088418929453e-5 -1.076712507320653e-5 9.25096208305089e-5 7.81828092321251e-6 -0.00020047434028648157 4.1063718330670747e-5 6.853681023521233e-5 6.872882275664967e-5 -6.109015904777023e-5 -7.839380506603282e-5 1.205156766836299e-5 -4.385025769416561e-5 -8.701519814811636e-6 8.504533783972208e-6 2.8711707076853136e-5 9.50116043752959e-5 0.0001767015089938392 6.13862464745525e-5 7.926852757399694e-5 -6.331533606108468e-5 -8.021766208522318e-5 -4.195265884455348e-5 -3.419450169140784e-5 5.644389219404812e-5 -0.00010492922976091804 0.00010619634466301429 -4.2158397460031804e-5 -0.0001129542617633969 0.0001255996355978746 7.816153615105687e-6; -8.12781320150404e-5 7.173848713641778e-5 0.00011626916672460706 0.0001854446134216691 -0.00020381869693160802 -6.0037276156845893e-5 6.671626466918533e-5 7.23671153223679e-5 7.722039995810573e-6 -8.124282644032423e-6 1.179325081965703e-6 7.057674226011676e-5 -6.783366119130551e-5 2.7521386769467778e-5 -2.4894368138530483e-5 0.00010546916452362852 7.960233484985309e-5 7.596790671009364e-5 0.00018582858025688312 -0.00010862736529312038 7.695531964195419e-5 -0.00010250160561049872 4.80961104180992e-6 4.911207966455031e-6 -4.715505859571731e-5 -3.9322650295116205e-5 -3.437119616546694e-6 0.00028107258255254033 -9.090979004618408e-5 -9.78687705669017e-5 -4.6026139202195834e-5 -0.00018430115891028665; 0.0001229162222287178 0.0002594040484680356 -8.295985064005215e-5 -1.7807491910594327e-5 -9.380661519797953e-6 -6.413558768802473e-5 1.1515914214009614e-5 1.4993972020214488e-5 0.00010438809802758781 -2.134577372747698e-5 0.00010422542944322202 2.2872445317958277e-5 -3.211707409469287e-5 
0.00015243639753205067 -5.791290861200613e-5 0.00018419502119538364 -5.112345586254562e-6 -9.297693797382243e-5 9.630920938686129e-5 -3.872926792360351e-5 -0.00012026447113133777 9.296085007259648e-5 7.656109626875359e-5 1.9556420359319802e-5 0.00015192799726971657 8.583858978706699e-5 -6.34139727637625e-5 0.00022084142945710895 7.559327749073088e-5 3.264052755437016e-5 2.1220262699688276e-5 -0.0001750894665903951], bias = [-2.6305502071499187e-9, -1.1905828627179324e-9, -2.576838088114409e-9, -2.1612722184538984e-9, 3.214973720681143e-10, 6.366644181871915e-10, 2.8382548290852466e-9, 1.0842361276485243e-9, -7.522169496503006e-10, 5.891447139626611e-10, 9.22436584603265e-10, 1.3886293971729812e-9, 5.98096910371713e-10, -4.583494365274096e-9, -3.863132642099283e-9, -1.10206383654289e-9, 3.680939663041263e-9, 4.328440480589899e-9, -3.0794041661659663e-9, -6.49808818358469e-10, 1.1475442898404305e-9, -2.1553720134965717e-9, 4.6099612249646707e-10, 3.0522552389958867e-10, -3.6693098080930016e-11, 5.962625348226679e-10, 4.3087037391978566e-10, -4.5197999426656625e-10, 9.992908619458862e-10, -4.0401183386750624e-10, 9.720366055377528e-10, 3.732186791689492e-9]), layer_4 = (weight = [-0.0006108862745047021 -0.0006847167631872372 -0.0008833424982244536 -0.0007156669221972863 -0.0008050274922441746 -0.0007272186192600685 -0.0006680349298343259 -0.0007844690559370944 -0.0007372736678410596 -0.0007099491279955244 -0.0005598885511582027 -0.0006964536407049327 -0.0006546851921392669 -0.0006011339461028504 -0.0007515690314709679 -0.0005897344105224736 -0.000722343865474538 -0.0007707114465270688 -0.0004890767541183106 -0.0006855437782326744 -0.0006840379367135967 -0.0006917773342029589 -0.0009161605092181394 -0.0007143966911845198 -0.000579634182511351 -0.0006236782658754548 -0.0006051366491383553 -0.0007662397378437101 -0.0008577251766121424 -0.0006543468824973079 -0.0005475131340124602 -0.0007357206560183073; 0.00031803453709212617 0.00015453753296430016 0.0002178831055986803 
0.00022588952173446456 0.0002901736329985614 0.00010681257895248892 0.0002118638891519666 0.00012816209386412692 0.0003107586660401675 0.0002610342627055164 0.00014564837739078392 0.0001246036420357842 0.0002421250847581 0.00024738413554525787 6.594988073323309e-5 0.00020569601547665254 0.00018480582553796548 0.0004011116140801918 3.0550674050906125e-5 0.00010324948420705885 0.00025474075415429624 0.00021183285527020195 0.0002667828833131131 0.0003148173597727156 0.000274555234363389 0.0003542832137814762 0.0005411236877604933 0.00026841728550967804 0.00011046136388748102 0.00019724353754288048 0.0003230527467534024 0.00036751319741848957], bias = [-0.0006671407205116781, 0.00023424894303618842]))
Visualizing the Results
Let us now plot the loss over time
begin
    # Loss history collected by the callback, one entry per callback invocation.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, 1:length(losses), losses;
        strokewidth=2, markersize=12, marker=:circle)
    fig
end
Finally, let us visualize the results.
# Re-integrate the neural ODE with the optimized parameters `res.u` and rebuild the
# corresponding waveform for the final comparison plot.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4();
    u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false) |> Array
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    # Compare the data against both the untrained and the trained neural ODE waveforms.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Prediction before training.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Prediction after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5)

    # Each legend entry groups a line with its markers.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU toolkit details are printed only when MLDataDevices and the matching backend
# package are loaded and the backend reports itself functional. `&&` short-circuits, so
# nothing from an unloaded package is evaluated.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.