Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, these details are not essential.
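We need a very crude two-body path. Assume the one-body motion is the Newtonian two-body relative position vector $r = r_1 - r_2$; the individual positions then follow in the centre-of-mass frame as $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$ with $M = m_1 + m_2$.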
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the positions of two bodies of
masses `m₁` and `m₂` in their centre-of-mass frame.

Returns `(r₁, r₂)` with `r₁ = (m₂ / M) .* path` and `r₂ = (-m₁ / M) .* path`, where
`M = m₁ + m₂`, so that `r₁ - r₂` reproduces `path` (up to rounding).
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)
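Next we define a function to perform the change of variables from the orbital variables $(\chi, \phi)$ (together with $p$ and $e$) to Cartesian coordinates, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, $y = r\sin\phi$: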
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution expressed in orbital variables into a planar Cartesian orbit.
#
# `soln` has one column per time sample and either
#   * 2 rows `(χ, ϕ)`: `p` and `e` are then read from `model_params = (p, M, e)`, which
#     must be supplied and have length 3 (`M` is not used here), or
#   * 4 rows `(χ, ϕ, p, e)`, where `p` and `e` may vary per sample.
#
# The radius is `r = p / (1 + e cos χ)`; the result is a `2 × size(soln, 2)` matrix whose
# rows are `x = r cos ϕ` and `y = r sin ϕ`. Violated preconditions raise an AssertionError.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` first so a missing argument yields this message rather
        # than a MethodError from `length(nothing)`.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)
This function approximates the first time derivative with central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` (step `dt`).

Interior samples use the central difference `(v[i+1] - v[i-1]) / (2dt)`; the first and
last samples use second-order one-sided stencils. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    forward = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    central = (v[3:end] .- v[1:(end - 2)]) / 2
    backward = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(forward, central, backward) / dt
end
d_dt (generic function with 1 method)
This function approximates the second time derivative with the standard three-point stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` (step `dt`).

Interior samples use the three-point stencil `(v[i-1] - 2v[i] + v[i+1]) / dt^2`; the
first and last samples use second-order one-sided four-point stencils.
Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    forward = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    central = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    backward = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(forward, central, backward) / (dt^2)
end
d2_dt2 (generic function with 1 method)
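Now we define a function to compute the components of the trace-free (quadrupole) moment tensor from the orbit, followed by helpers that turn it into the (2,2)-mode gravitational-wave strain.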
"""
    orbit2tensor(orbit, component, mass=1)

Return the time series of one component of the trace-free moment tensor of a planar
`orbit` (row 1 = x, row 2 = y), scaled by `mass`.

- `component == (1, 1)`: `x^2 - (x^2 + y^2) / 3`
- `component == (2, 2)`: `y^2 - (x^2 + y^2) / 3`
- any other index pair: the off-diagonal `x * y`
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    radius_sq = x .^ 2 .+ y .^ 2
    moment = if component[1] == 1 && component[2] == 1
        x .^ 2 .- radius_sq ./ 3
    elseif component[1] == 2 && component[2] == 2
        y .^ 2 .- radius_sq ./ 3
    else
        x .* y
    end
    return mass .* moment
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (finite differences with step `dt`, via
`d2_dt2`) of the requested trace-free moment-tensor component of `orbit`
(computed by `orbit2tensor`).
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole strain components for a single body
on `orbit` with the given `mass`, each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the two polarizations `(√(π/5) (h11 - h22), -√(π/5) · 2 h12)` of the (2,2)-mode
strain for a single body (unit mass) on `orbit`. Constants are built in the type `T`
of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    amplitude = sqrt(T(π) / 5)
    plus = amplitude * (h11 - h22)
    cross = -amplitude * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return `(h11, h12, h22)` for a two-body system: each component is the sum of the
per-body components from `h_22_quadrupole`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # Component-wise sum of the two (h11, h12, h22) tuples.
    return map(+, body1, body2)
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the two polarizations `(√(π/5) (h11 - h22), -√(π/5) · 2 h12)` of the (2,2)-mode
strain for two bodies on `orbit1`/`orbit2` with masses `mass1`/`mass2`.

The masses must sum to 1 to within `1e-12` (checked with `@assert`).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    amplitude = sqrt(T(π) / 5)
    return amplitude * (h11 - h22), -amplitude * (T(2) * h12)
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Convert the ODE solution `soln` to a Cartesian orbit (via `soln2orbit`) and return the
two (2,2)-mode strain polarizations sampled with step `dt`.

`mass_ratio` must lie in `[0, 1]`.
- `0`: the test-particle limit, handled by `h_22_strain_one_body`.
- Otherwise: the orbit is split into two bodies with `m₂ = 1 / (1 + mass_ratio)` and
  `m₁ = mass_ratio * m₂`, then passed to `h_22_strain_two_body`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # Test-particle limit: no second body, so use the one-body strain directly.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    # Masses normalised so that m₁ + m₂ = 1 (up to rounding) and m₁ / m₂ = mass_ratio.
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
compute_waveform (generic function with 2 methods)
Simulating the True Model
RelativisticOrbitModel
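defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using the state $u = (u_1, u_2) = (\chi, \phi)$:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p-2)^2 - 4e^2}},\qquad
\dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.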
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE. `u = (χ, ϕ)` and the constants are
passed as the tuple `(p, M, e)`; `t` is unused. Returns `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)
    # Factor shared by both derivatives.
    numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end
# Settings for the "true" (relativistic) reference simulation.
# NOTE(review): apart from `ode_model_params`, these globals are not `const`, yet several
# of them are read inside `loss` later on — consider marking them `const` as well.
mass_ratio = 0.0 # test particle: compute_waveform takes the one-body branch when this is 0
u0 = Float64[π, 0.0] # initial conditions (χ₀, ϕ₀)
datasize = 250 # number of saved samples
# NOTE(review): tspan is Float32 (so tsteps/dt_data are Float32) while u0 and the model
# parameters are Float64 — confirm this mixed precision is intended.
tspan = (0.0f0, 6.0f4) # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1] # spacing of the saved samples (step for the finite differences)
dt = 100.0 # fixed RK4 integration step (the solves use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic model with fixed-step RK4 (step `dt`), saving at `tsteps`;
# `soln` is a matrix with one row per state variable (χ, ϕ).
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first returned polarization; this is the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the reference waveform as a line plus markers and show a single legend entry.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
# Group the line and the markers into one legend entry.
axislegend(ax, [[l, s]], ["Waveform Data"])
fig
end
Defining a Neural Network Model
Next, we define the neural network model, which takes 1 input (the orbital variable χ, i.e. the first state variable) and has two outputs. We'll make a function ODE_model
that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using global variables is generally discouraged, but if you do use them, make sure to mark them as `const`.
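We will deviate from the standard neural network initialization and use WeightInitializers.jl: the weights are drawn from a truncated normal distribution with a small standard deviation (1e-4), and the biases start at zero.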
# Network mapping a length-1 input (χ) to two outputs. The first layer applies `cos`
# elementwise to the input. Weights are drawn from a truncated normal with std 1e-4 and
# biases start at zero, so the untrained network's outputs should be small.
const nn = Chain(Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
# NOTE(review): the RNG is not seeded, so the initial parameters differ between runs.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-9.106483f-5; -0.00013703329; 1.0647951f-5; -9.399336f-5; -7.167876f-5; 3.22397f-5; 0.00010990998; 3.6660753f-5; -3.700739f-5; -8.2176186f-5; -4.7703477f-5; -0.00021417561; 6.709479f-5; -5.517798f-5; 2.9039576f-5; 1.7956616f-5; 2.3073078f-5; 7.5196294f-5; 0.00016941824; 8.117737f-5; 4.7875743f-5; 0.000116946794; 4.8856666f-5; -6.0917693f-5; 1.6190086f-5; -7.135145f-5; 2.7546614f-5; 4.2319596f-5; 9.941526f-5; -0.00023735237; -9.674531f-5; -9.546929f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[3.409279f-5 1.38265905f-5 -3.8573504f-5 0.00025541912 -4.4698416f-5 -2.8616698f-5 8.310035f-5 -0.00015323065 5.9766688f-5 -9.75961f-6 1.5265034f-5 -4.997671f-6 -8.117764f-5 0.00011592936 -2.2823091f-5 -0.00010421868 6.127755f-5 0.000111308334 9.395176f-5 -0.00010543964 -8.652912f-5 7.881796f-5 2.4231413f-5 -1.6318454f-5 1.0618442f-5 7.108511f-5 0.000106448264 -8.008665f-5 0.00011204211 7.4949303f-6 -0.00012282596 8.0800884f-5; 7.0322235f-6 -7.34972f-5 9.503138f-5 -4.0829265f-5 -0.00016326284 -8.270626f-5 -6.556648f-5 5.1359726f-5 -8.689507f-5 -1.2264289f-5 0.00013517142 6.931681f-5 -3.5591536f-5 -4.311031f-5 2.0984351f-5 0.0001802892 0.00012415947 -9.1469024f-5 5.0882172f-5 9.6069925f-5 0.00016210102 6.9922564f-5 -4.101245f-5 -8.6601336f-5 -8.554137f-6 -4.4065015f-5 2.734941f-5 -4.7469897f-5 -0.00016484785 -2.4910336f-5 -0.00018621459 -6.433533f-5; 3.646045f-5 6.137108f-5 -0.000100343925 2.8210372f-5 6.858548f-5 -6.629523f-5 0.000108677625 6.6408254f-5 2.5466696f-5 -0.00012157826 2.4738874f-5 0.00013375943 0.00010471074 -5.3011227f-5 -3.728043f-5 2.028235f-5 0.00014732215 -1.4151517f-5 8.661068f-5 1.5248058f-5 -2.6362384f-5 0.00011041911 -8.204823f-5 -0.00013561682 -0.00018027567 -0.00010830433 -0.0001648058 3.2379052f-5 -5.1039675f-5 -0.00016676249 5.579321f-5 
-4.9139362f-5; 5.819366f-5 8.0277096f-5 2.1627257f-6 4.3690343f-5 -0.00012802865 5.5094897f-5 2.3283696f-6 -9.9071054f-5 -1.5664817f-5 0.00010581462 0.00013649373 5.56986f-5 -2.3608216f-5 -2.4964449f-5 6.511591f-5 7.422402f-5 5.653532f-5 -1.8782324f-5 -0.00010729828 -7.912977f-5 5.891438f-5 2.5209643f-5 -0.0002846946 2.0386113f-5 4.9214796f-5 -0.00013646159 -4.496451f-5 7.990376f-5 -0.00013622241 5.2069823f-5 5.3286978f-5 2.9645682f-5; -0.00018012316 -1.308541f-6 4.168893f-5 -0.00014149045 -0.00010193425 1.31669685f-5 4.3053897f-5 5.9686317f-5 -0.00010708086 -9.9266785f-5 -0.00021224651 -0.00013689458 -0.00016374167 0.00014963704 3.4560497f-5 3.1051928f-5 -8.9509555f-8 3.049641f-5 0.00020987426 2.2433348f-6 -3.5117784f-5 -0.00042839756 -0.00012921274 -1.3618142f-5 6.314127f-5 4.4409706f-5 -0.000234889 -7.931935f-5 1.3755291f-5 0.00014030318 -0.00011678421 -1.087993f-5; -0.00013306264 6.4934866f-6 8.5255415f-5 3.4734756f-5 -6.501932f-5 0.0002483612 -4.601828f-5 -0.0001839534 0.0002956941 4.599014f-5 -0.0001539515 5.765776f-5 -0.000117193005 -7.259512f-5 7.445684f-5 -7.746618f-5 -4.7599395f-5 -0.00025245792 7.080499f-5 -0.0001276052 -0.00017125909 2.7215085f-5 8.4325155f-5 -0.00010992595 -0.00024362106 -9.867182f-5 0.00012029006 5.9548645f-5 -5.787314f-6 0.00012841319 6.1948034f-5 3.711318f-5; 0.00010641905 3.0565083f-5 -0.00015357311 -6.1773935f-5 -4.9897357f-5 -7.7353674f-5 0.00016893762 -0.00013495529 -0.000100369274 -0.00012189466 -7.313658f-5 -7.486975f-6 4.9814153f-5 -6.800949f-5 -3.479003f-5 9.9174154f-5 -0.00011762143 2.451618f-5 -0.00013099254 0.00014140991 0.00018850251 6.9212896f-5 -2.5139725f-6 5.473167f-6 9.7319265f-5 0.00017480241 6.473985f-5 -1.6890444f-5 0.00013044664 1.7397626f-5 0.00018953452 0.00012530616; 0.00018008782 0.00015171788 -3.962373f-5 5.7438247f-6 -2.9629104f-5 -0.00013211551 -4.777413f-5 8.9404435f-5 5.7108984f-5 -7.110346f-5 -5.343813f-5 1.257604f-6 0.00011174503 -0.00013409728 -6.0856066f-5 0.00022668425 7.4093814f-6 5.253423f-5 
5.960278f-5 -4.3886615f-5 0.00014285544 4.439935f-5 -8.802729f-6 -9.189847f-5 -2.7791699f-5 0.0001470197 -8.709496f-5 0.00014233869 0.00021418819 7.609549f-5 -9.459446f-5 4.6119498f-5; 0.0001041883 -3.4916861f-6 5.8636106f-6 -3.576258f-6 5.933026f-6 -2.7943666f-5 -5.9072245f-5 3.4977245f-5 5.2147137f-5 -7.1128657f-6 0.00011290618 -7.148453f-5 4.070048f-5 -0.00017108631 -0.000104335675 -8.6210006f-5 1.4632889f-5 -4.425758f-5 -0.00012063545 5.325348f-5 -1.3925006f-5 -4.00572f-5 9.350735f-5 2.8539862f-5 -8.047137f-5 8.0769234f-5 0.00016283152 0.00013231602 -6.440291f-6 9.5545234f-5 0.00013957047 1.2261067f-5; -5.628223f-5 -3.5199933f-5 -5.993784f-5 -0.00023398815 -0.0001855073 -0.00010522872 0.00021251824 0.00011920197 -0.0001923145 -0.00010994982 9.32239f-5 -3.8284277f-5 -0.00012965822 7.333779f-5 4.1037812f-5 -3.8516926f-5 -0.00012904391 -4.9743823f-5 -6.4289134f-6 4.8760692f-5 -2.965942f-5 -4.6756493f-5 -3.1719403f-5 9.117879f-6 -6.5620654f-5 3.148789f-5 1.5998048f-5 9.28991f-5 1.2223345f-5 -0.00015057366 7.429329f-5 4.3126915f-5; -1.4801366f-5 -6.0063714f-5 -8.4170635f-5 0.00011874647 3.071896f-6 -4.3614233f-5 2.2180964f-5 -1.611298f-5 -1.3759409f-5 -6.934523f-5 -2.8202148f-5 -3.8781167f-5 -0.00012155408 -3.254021f-5 -0.000109993096 6.788526f-5 -0.00017156219 0.00017800808 -4.5713296f-6 5.5101344f-5 0.00010848662 -6.5378736f-5 -0.00030882593 -0.00015075327 0.00016215174 -1.0478718f-5 -6.7561334f-5 -0.00010668413 -4.2789277f-5 -3.065223f-5 3.4620272f-5 -8.9829606f-5; -5.2238578f-5 7.4702635f-5 9.986429f-5 3.5260407f-6 4.0058818f-5 8.599951f-5 4.478991f-6 -0.00013205322 4.2600324f-5 -4.9344642f-5 0.00021312384 0.00017601211 6.180301f-5 -1.321815f-5 0.00015666905 3.9179333f-5 -8.9101886f-5 -1.1533673f-5 7.0716815f-5 -0.00017358607 7.734968f-5 -2.1533704f-6 -7.8910576f-5 -0.00011796672 2.8596152f-5 -3.0290135f-5 6.8797344f-5 0.0002242371 0.00018898943 -1.628186f-5 4.3043906f-6 -6.153557f-5; 1.0241612f-5 -3.8286624f-5 0.00011174086 -2.1334126f-5 -6.108365f-5 
0.00012473353 -8.4354186f-5 7.693881f-5 -8.4931475f-5 0.00022587152 -8.9079236f-5 -2.7485643f-5 -2.5489782f-5 -5.2289448f-5 -1.22850315f-5 2.5157535f-5 -9.682391f-6 1.635988f-5 -0.00011232891 -7.4590745f-5 2.6057258f-5 -1.0367329f-5 -0.00016086923 -0.0001950892 -8.0175625f-5 -1.07811475f-5 9.6042415f-5 0.0001673539 -0.00018875078 -2.5478312f-5 3.5504988f-6 8.435957f-5; -0.00015799781 5.683755f-6 7.2997536f-5 -0.00014035452 -1.5893549f-5 -3.798799f-5 0.00014241989 0.00018419695 0.00014328204 0.00010672899 -3.0525975f-5 -6.391091f-5 -0.00015753244 -0.00014245164 6.502888f-5 -0.00017187951 -1.931983f-5 8.314586f-5 -2.2894606f-5 5.387625f-6 1.3516479f-5 3.2287953f-5 9.449488f-5 0.000108292064 4.8061185f-5 -1.9899328f-5 -0.0001308993 -4.1417297f-5 1.3014289f-5 -2.0047904f-5 -6.2604313f-6 6.0502072f-5; -2.8797424f-5 -8.247879f-5 -3.5444962f-5 3.0237894f-5 -1.1138215f-5 -4.499256f-6 -8.219493f-5 1.3128087f-5 0.00010766595 -0.00015622907 2.8441631f-5 -7.387948f-5 -0.000108637534 -6.204475f-5 3.6240614f-5 -3.9931056f-5 0.000114407354 -0.00016217011 1.6394617f-5 3.112688f-5 -0.00015150731 -0.00028460103 8.364025f-6 0.00017535897 4.137068f-5 9.6441814f-5 2.469417f-5 -9.500595f-5 0.00015576858 6.9843016f-5 -0.00011170346 -6.980894f-5; 0.00018190654 -0.0001480987 -0.00012562236 0.00014152606 7.5499345f-5 3.945816f-5 -7.507963f-5 -0.000105105995 4.6955142f-6 0.00017472445 -3.2931213f-5 -0.00015917797 7.925976f-6 -6.855978f-5 -0.00013110471 -8.823695f-5 4.3106993f-5 0.000115125986 6.244611f-5 4.6641075f-5 -0.00015516294 -0.00020745411 -0.00021340787 5.733618f-6 0.000118703756 0.00019720408 0.0001364211 -3.9934166f-5 -5.7479214f-5 -8.198869f-5 -4.1084795f-6 -6.7383786f-5; 8.9856665f-5 -0.00019555086 1.1586149f-6 0.00025681563 -1.5281525f-5 -0.00012787056 -0.00014725597 0.00011075036 1.2743101f-5 -6.375082f-5 0.00015331083 -9.434826f-5 5.1670708f-5 5.741999f-6 -0.0001986923 -4.31617f-5 -1.4294788f-5 -0.00010946372 0.00010347143 2.6183156f-5 0.00014376496 -8.9296245f-6 
-0.00015775248 9.954849f-5 -0.0001547561 3.661258f-5 0.0001620128 -6.960361f-5 -6.3539424f-5 -0.00012349278 0.00018393814 6.691048f-5; -9.7848926f-5 0.00010688668 2.159343f-5 7.794842f-5 5.577541f-5 -0.00016420851 -0.00018098461 7.889311f-5 -3.6791353f-5 -3.0131112f-5 8.5875785f-5 4.1045616f-5 -8.982642f-5 -0.00011088922 0.00014230981 -0.00015063582 -2.5491372f-5 -5.3660773f-5 0.00011551379 0.00011048723 -7.718673f-5 -0.00016002644 0.000104442996 0.0002020345 -2.3942244f-5 8.303071f-5 7.114059f-7 -6.893028f-5 2.7509743f-5 -4.6649642f-5 -2.7034783f-5 3.7618254f-5; 0.00014617911 -0.00014200828 -5.0149665f-5 9.73183f-5 4.7217476f-5 -3.503267f-5 1.1653027f-5 -0.00010518211 1.1660775f-5 9.5198586f-5 -0.00012426799 6.5325934f-5 5.5668206f-5 0.00014342573 -8.4455445f-5 -3.2751122f-5 -3.0891762f-5 -7.313492f-5 -2.5965857f-5 6.005243f-5 0.000108158914 0.00028660725 -3.216497f-5 3.554473f-5 0.00010127056 -2.2500305f-6 -6.778313f-5 -1.8794854f-5 -6.197592f-5 -0.00011895406 -7.617648f-5 -3.6750203f-5; -0.00011059225 9.2640395f-5 -0.00012611975 -2.6946014f-5 -0.00017863952 9.7976714f-5 -3.7936803f-5 7.493955f-5 -3.633735f-5 -2.4824316f-5 -2.893001f-5 3.9087005f-5 2.9626555f-5 8.040248f-5 0.00021074369 1.3264995f-5 -1.783349f-5 3.554457f-5 -8.922172f-5 1.06118005f-5 -0.00013921494 0.0001564029 0.00011144703 5.9777354f-5 0.0001498234 4.577696f-5 -0.00018925074 7.665232f-5 -5.7828984f-5 -8.115157f-5 -7.058409f-5 6.8517096f-5; -7.746627f-5 -0.0001871907 0.00016120909 4.565481f-6 -7.7396355f-5 0.00017110757 -4.7768255f-5 -4.5524554f-5 0.00016367999 0.00020235805 -1.525371f-5 7.682032f-5 3.185958f-5 7.845409f-5 -0.000108541964 -3.477449f-5 7.529727f-5 2.1644568f-5 -0.00012082392 1.4739273f-5 -7.164562f-5 -4.3803557f-5 7.330212f-5 2.4374247f-5 -0.00015467995 7.6694654f-5 -2.2451434f-6 0.00012762671 8.608951f-5 0.00014439358 8.609201f-6 -7.0604715f-6; -3.1813215f-5 0.00014757678 7.30702f-5 -0.000101640464 -3.1559925f-5 -0.00010503576 8.378376f-7 -2.4047087f-5 -6.107883f-5 4.0441962f-5 
8.046604f-6 4.7784113f-5 -5.146147f-5 0.00010208261 -5.2016294f-5 -0.0002592163 -7.448122f-5 0.00012395624 -0.00016794344 0.00016726088 -3.238296f-5 9.544915f-5 -1.8373561f-5 2.9271201f-5 0.00012364567 -2.371423f-5 -1.613679f-5 3.658526f-5 -4.477964f-5 -3.9880644f-5 -9.4829564f-5 1.4366582f-5; -0.00014612434 -6.1715815f-5 -1.7585442f-5 -8.3691f-5 -7.007863f-5 0.00013104432 3.438746f-5 0.00011813031 8.916577f-5 -8.555951f-6 6.7185705f-5 0.00015744631 4.320227f-6 5.178699f-5 0.00016587085 -9.584116f-5 7.798433f-5 1.9292433f-5 1.03240026f-7 -4.7290552f-5 -0.00014452635 -0.000109657194 -1.195879f-5 0.000108832115 -9.2329574f-5 0.00010395199 -2.4130064f-5 0.00025711954 7.3371964f-5 -0.00011676899 -1.3328394f-5 -4.7604775f-5; 2.0019796f-5 -0.00016852646 8.559341f-5 1.4416565f-5 -0.000103649705 3.2820382f-5 -9.0653586f-5 0.00015115767 -0.00010012144 8.38391f-5 0.00017897828 2.9417564f-5 -0.00011427061 0.0003321899 8.481061f-5 -0.00016472048 -4.2556487f-5 3.0099469f-5 -0.00022038762 -0.00026990467 -5.894521f-5 6.00233f-5 6.055453f-8 -0.000103486534 0.00017839257 -8.35268f-5 2.6007148f-5 -0.00010286733 -2.3422732f-5 -2.2563217f-6 0.00011921658 5.30098f-5; 2.3348844f-5 3.604623f-5 8.67744f-5 0.00021189696 1.5691921f-5 -4.6824007f-5 -3.492487f-5 2.7596772f-5 0.0001183519 3.2885113f-5 2.284798f-5 -3.1042004f-5 6.4631f-5 3.4055596f-5 7.817448f-5 -2.6067997f-5 8.001481f-5 8.2300045f-5 0.00015287829 4.192782f-5 -0.00017660153 0.00019750929 -0.00013346424 -0.00015532286 0.00010011041 8.356955f-5 9.321625f-5 -1.1643716f-5 0.00013033643 6.56447f-5 4.7558115f-5 0.00013482594; -3.91767f-5 7.700568f-6 -0.0001185786 4.4215557f-7 3.3117285f-5 0.00016941223 1.4274308f-5 2.2931968f-5 -4.5713474f-5 3.7351794f-5 9.607333f-5 0.00018051703 -1.3076852f-5 0.00014160428 7.96112f-5 -6.401511f-5 -0.00011726349 -1.537543f-5 7.407969f-5 -2.4584508f-6 9.141959f-6 -0.00014013963 -2.555742f-5 -8.080161f-5 -0.00014072409 -2.1424454f-5 0.00012757412 -2.4589794f-5 -0.00014655024 0.00017502486 0.00020167996 
-1.748596f-5; -0.00013108552 -3.368623f-5 -3.3569395f-5 5.729566f-5 -2.1676111f-5 0.00012596823 -8.139518f-5 -0.00023048109 -0.00011592543 3.6174857f-5 0.000107389395 -5.9029423f-7 8.572117f-6 -2.2858792f-5 -8.105153f-6 1.0572617f-5 -6.7559718f-6 -0.00022278541 0.00015381297 6.281469f-5 0.00012839843 -0.0001395971 2.6728401f-5 5.1307412f-5 -5.318831f-5 2.5101765f-6 5.977508f-5 -8.571643f-5 3.1448606f-5 2.2379661f-5 3.6426456f-5 -1.3724931f-5; 0.00014587906 0.00013772136 -0.00012261023 8.152572f-5 -2.3684637f-5 0.00011664137 -2.2002316f-5 7.200345f-5 -9.0214526f-5 5.7705114f-5 -0.00012112146 9.13321f-5 2.5115529f-5 5.9500668f-5 8.153258f-5 0.00018217358 -9.384078f-5 -5.552578f-6 -5.4035f-5 0.00013208787 -0.00015828917 0.00019338606 -8.149844f-5 0.00014916125 0.00019622609 0.00017246061 -1.3819726f-5 -6.874983f-5 0.00012616589 -0.00011582322 -1.8563836f-5 -7.829109f-5; 5.8908205f-5 -0.00012822931 -3.8258186f-6 0.00013244669 4.389325f-5 -4.59058f-5 -5.1414776f-5 -0.00010760851 -1.2172934f-5 0.00024728727 -4.0660296f-5 0.00017909492 -8.521957f-6 -0.000120362165 -6.756756f-5 8.5160915f-5 -1.52915f-5 -0.00010638268 0.000107518295 -1.46455295f-5 9.632722f-5 -4.0501167f-5 4.9044927f-5 -2.9876473f-5 5.1831237f-5 2.2192711f-5 -7.754864f-5 7.290985f-5 0.0001339954 9.6881544f-5 -5.5799952f-5 -7.3818305f-6; -7.934372f-6 -0.000114471586 6.1144956f-6 5.7925332f-5 -1.09768f-5 -3.7164675f-5 3.615773f-5 4.7817575f-5 8.5008876f-5 -9.028778f-5 -4.8498485f-5 2.619344f-5 -7.002179f-5 -1.777481f-5 -3.099093f-5 0.00019743692 -3.7834227f-5 -2.9609262f-5 1.8811019f-5 5.0755425f-5 0.00014773429 5.621369f-5 -0.00014504213 -8.715143f-5 8.926851f-5 -0.0001272147 -2.4769653f-5 -1.9215462f-5 4.3706863f-5 -0.00014794405 9.887155f-5 -0.000113635746; -7.328712f-5 0.00011838675 7.868176f-5 -6.538742f-5 3.26092f-5 -5.073025f-5 -3.8468657f-5 -3.4809953f-5 9.431314f-5 -7.2933115f-5 -7.8569836f-5 6.2495536f-5 -9.649892f-5 0.0001377718 -0.000110046225 5.1349343f-5 -7.745138f-5 0.00012356848 -0.00023877127 
-0.00015943694 7.966769f-5 -2.7064525f-5 5.3260374f-5 9.78728f-5 -1.9376292f-5 0.00011191838 -0.00012611928 2.221507f-5 8.090958f-6 7.3453324f-5 0.000258359 0.00014288128; -0.00013517107 -6.602396f-5 -7.9907535f-5 0.00012663302 1.1115857f-8 0.00011151947 3.2980795f-5 2.2351338f-5 8.459579f-5 8.751299f-6 -1.5471047f-5 -0.00011194026 1.7939323f-5 -0.00014748558 0.00010220331 5.852279f-5 -1.984863f-5 -4.3934528f-5 0.00015320611 -1.7365757f-5 1.4679871f-5 0.00011391602 7.120811f-5 1.8749992f-6 -0.00011688845 -2.2471573f-5 -5.9698235f-5 -1.5737389f-5 0.00010888079 -0.0002786921 6.2514055f-6 -4.0113322f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-3.2247547f-6 0.00020973592 0.00014815881 5.4111465f-6 4.2445365f-5 4.944919f-5 -3.908078f-5 -5.3296397f-5 7.1942004f-6 -0.00021342214 5.7447935f-5 4.634191f-6 -9.27962f-5 -2.9991163f-5 -0.00014941843 0.00020431582 8.41824f-5 -1.3391325f-7 -9.494906f-6 1.4045917f-5 -3.44961f-5 4.4333436f-5 -3.7272148f-5 0.0001187127 2.4289004f-5 -0.00015269149 -5.8588215f-5 0.00022618052 -1.5949627f-5 2.9853195f-6 -8.662131f-5 8.979611f-7; -2.8530836f-5 -5.2006268f-5 -9.111273f-5 -0.00016493234 1.6379716f-5 0.00010915025 -0.00012949077 -0.00013625609 6.530208f-5 0.00016703292 5.7105943f-5 -4.9472434f-5 3.7414004f-5 -0.00011802978 -0.00012529017 -7.880395f-5 -7.877614f-6 -3.9498114f-5 0.0001822672 0.00012599745 0.00023978624 -4.7785856f-5 4.819839f-5 3.7724843f-5 -2.0755779f-5 8.9100824f-5 -8.8604735f-5 3.8819475f-5 7.80915f-5 -0.000111455294 -5.793465f-6 -8.24349f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most DL frameworks, Lux defaults to using `Float32`; however, in this case we need `Float64`.
# Convert the Float32 parameters to Float64 and store them in a ComponentArray, the
# parameter container handed to the ODE problem and the optimizer below.
const params = ComponentArray(ps |> f64)
# Bundle the network with its state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
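Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, multiplicatively corrected by the neural network $\mathcal{N}_\theta$ (evaluated on $\chi$), using $u = (\chi, \phi)$:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_{\theta,1}(\chi)\bigr),\qquad
\dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_{\theta,2}(\chi)\bigr),$$

where $p$, $M$, and $e$ are the same constants as in the relativistic model.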
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural orbit ODE (Newtonian rate times a learned correction).

`u = (χ, ϕ)`, `nn_params` are the network parameters, and `t` is unused. The constants
`p, M, e` come from the global `ode_model_params`. Both derivatives share the Newtonian
rate `(1 + e cos χ)^2 / (M p^(3/2))`, each scaled by `1 + N(χ)[k]`, where `N` is the
global `nn_model`. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # `nn_model` is a `StatefulLuxLayer`, so it carries the network state `st` itself
    # (an empty NamedTuple for this model) and only the parameters are passed here. For
    # networks with non-trivial state, `st` is where that state would be kept.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
ODE_model (generic function with 1 method)
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Simulate the neural ODE with the untrained Float64 parameters, using the same
# fixed-step RK4 settings and save points as the reference simulation.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# First returned polarization, comparable with the reference `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
# Overlay the reference waveform and the untrained neural-ODE waveform.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
# One legend entry per series (line + markers), placed at the left-bottom.
axislegend(ax, [[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Simulate the neural ODE `prob_nn` with parameters `θ` (fixed-step RK4, saved at
`tsteps`), convert the trajectory into a waveform, and return the mean-squared error
between its first polarization and the reference `waveform`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    strain = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    return mseloss(first(strain), waveform)
end
loss (generic function with 1 method)
Warm up the loss function (the first call also triggers compilation):
# Evaluate the loss once at the initial parameters (first call compiles the code path).
loss(params)
0.0007281560239842478
Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(state, l)

Optimization callback: append the current loss `l` to the global `losses` vector and
print it together with the iteration counter `state.iter`. Always returns `false`;
Optimization.jl stops early only when a callback returns `true`.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
callback (generic function with 1 method)
Training the Neural Network
Training uses the BFGS optimizer. This gives good results here because the Newtonian model already provides a very good initial guess.
# Minimize `loss` with BFGS, a backtracking line search and a small initial step.
# Gradients come from Zygote. `callback` logs each iteration.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-9.106483048511145e-5; -0.00013703329022955258; 1.0647951057761933e-5; -9.399335976907616e-5; -7.167876174202258e-5; 3.223969906684196e-5; 0.00010990998271157609; 3.6660752812184353e-5; -3.700739034679843e-5; -8.217618596964835e-5; -4.7703477321158876e-5; -0.00021417561219963153; 6.70947920296687e-5; -5.5177981266760396e-5; 2.9039576475025662e-5; 1.7956615920395966e-5; 2.3073078409647086e-5; 7.519629434677248e-5; 0.0001694182428765159; 8.117737161230538e-5; 4.787574289367655e-5; 0.00011694679415044511; 4.885666567128151e-5; -6.091769319026973e-5; 1.619008617125749e-5; -7.135145278869661e-5; 2.7546613637214294e-5; 4.231959610472276e-5; 9.941525786393147e-5; -0.00023735237482454534; -9.674530883778696e-5; -9.546928777117981e-5;;], bias = [-1.748400022200869e-16, -2.1135192303534867e-16, 1.0300607340323707e-18, -1.7209139304774143e-16, -6.237725811652361e-17, 7.73303045908833e-18, -3.681711036818798e-17, -1.9158143170511834e-17, -5.1260987504682864e-17, -8.972622118959062e-17, -3.8170409055485344e-17, -4.738213129675166e-16, 1.7162016324614992e-16, 1.0083811165158888e-17, 1.9261832438937458e-17, 1.7601386901639714e-17, 5.526118194645539e-18, 6.011908796497967e-17, -1.0885369229055314e-16, -2.5269772774167365e-17, -1.0392074910907735e-17, 4.723723561056664e-16, 6.25864752691127e-17, -5.506497074942992e-18, 7.081099948470982e-18, -1.2075294538126676e-16, 3.0741438722287256e-17, 7.590828695730738e-17, 2.649692774369707e-16, -1.7218247485217108e-16, -7.605836600684566e-17, -9.894082414854222e-17]), layer_3 = (weight = [3.40945870568491e-5 1.3828386648328846e-5 -3.857170767038774e-5 0.000255420912174692 -4.469661970394671e-5 -2.861490164538613e-5 8.310214517516602e-5 -0.00015322885254751679 5.9768483800536296e-5 -9.757814120141407e-6 1.526682982624203e-5 -4.995874939411306e-6 -8.117584464959771e-5 0.00011593115679100272 -2.2821294884823622e-5 -0.00010421688156090402 6.12793436750796e-5 
0.00011131013062661477 9.395355929265279e-5 -0.00010543784545622857 -8.652732602674355e-5 7.881975663005798e-5 2.4233208853939216e-5 -1.6316657899342648e-5 1.062023776578995e-5 7.108690525749694e-5 0.00010645006015077482 -8.008485118513738e-5 0.00011204390367604507 7.496726460154445e-6 -0.00012282416818086315 8.080268063461168e-5; 7.031054680368946e-6 -7.34983669857694e-5 9.50302081733672e-5 -4.083043415758333e-5 -0.00016326400975540928 -8.270743012272622e-5 -6.556764594588815e-5 5.135855724492253e-5 -8.689623530099958e-5 -1.2265457483778042e-5 0.00013517024751908321 6.931564246461686e-5 -3.559270510177798e-5 -4.31114796114413e-5 2.0983182254523338e-5 0.00018028802412717708 0.0001241583039487444 -9.147019329557821e-5 5.0881003404904704e-5 9.606875653534536e-5 0.0001620998472177162 6.992139503981595e-5 -4.101361730445927e-5 -8.66025048921141e-5 -8.555305724629583e-6 -4.406618347584917e-5 2.734824090711427e-5 -4.747106606683536e-5 -0.00016484901891399643 -2.4911504736688198e-5 -0.00018621575611487518 -6.433649570990851e-5; 3.645947194686251e-5 6.137010254692706e-5 -0.0001003449021044204 2.820939457537714e-5 6.858450555969023e-5 -6.629620716180791e-5 0.00010867664727643342 6.640727710705122e-5 2.5465718270557592e-5 -0.00012157923870459157 2.4737896954453494e-5 0.00013375845205533665 0.00010470976608134778 -5.3012203945266836e-5 -3.728140896042022e-5 2.0281372915765523e-5 0.00014732117174221052 -1.4152493942540531e-5 8.660970101860155e-5 1.5247080544347779e-5 -2.6363361623092604e-5 0.0001104181331794735 -8.204920836929327e-5 -0.0001356177987434907 -0.0001802766516010293 -0.00010830530905136847 -0.00016480677245622649 3.2378074454887555e-5 -5.104065226246463e-5 -0.00016676346663356087 5.579223338659045e-5 -4.914033948066528e-5; 5.81936806506138e-5 8.027711786638339e-5 2.16274796769827e-6 4.369036518113088e-5 -0.00012802863008034082 5.509491942433047e-5 2.3283919639953325e-6 -9.907103181973692e-5 -1.5664794283626953e-5 0.0001058146449744828 0.0001364937566344753 
5.569862381747297e-5 -2.3608193679653487e-5 -2.496442672196905e-5 6.51159357752778e-5 7.422404500033115e-5 5.6535340753210364e-5 -1.8782302014703316e-5 -0.0001072982554083565 -7.912974930292592e-5 5.891440063208679e-5 2.5209665671748597e-5 -0.0002846945785327688 2.0386135055044294e-5 4.9214817898640064e-5 -0.000136461566818586 -4.4964486243376846e-5 7.990378120524216e-5 -0.0001362223915398926 5.206984545878432e-5 5.328700039029453e-5 2.9645704251773246e-5; -0.00018012860030940607 -1.3139856811786787e-6 4.168348592863962e-5 -0.00014149589927474744 -0.00010193969731991608 1.3161523806755521e-5 4.304845193910281e-5 5.9680872677544294e-5 -0.00010708630228256974 -9.927222939514697e-5 -0.0002122519531760365 -0.00013690002609576768 -0.00016374711643512904 0.00014963159398101404 3.4555051993063857e-5 3.1046483368097095e-5 -9.495427739025028e-8 3.0490965823253033e-5 0.00020986881051142714 2.237890125606232e-6 -3.512322885873729e-5 -0.000428403002420234 -0.0001292181864621116 -1.3623586287591441e-5 6.313582477084933e-5 4.4404261156874985e-5 -0.00023489444223307346 -7.932479141820314e-5 1.3749845916022929e-5 0.00014029773361915975 -0.00011678965760002528 -1.0885375136427628e-5; -0.00013306456895902644 6.4915577100294614e-6 8.525348601737591e-5 3.473282702957427e-5 -6.502124671483498e-5 0.000248359275414754 -4.602020709667004e-5 -0.00018395532316886214 0.00029569217366274127 4.598820958976651e-5 -0.0001539534358157244 5.765582971851024e-5 -0.00011719493415098762 -7.259704654230022e-5 7.44549099040741e-5 -7.746811099009974e-5 -4.760132361794265e-5 -0.0002524598465383431 7.080306312178725e-5 -0.0001276071350871026 -0.0001712610181270566 2.721315575813861e-5 8.432322573257351e-5 -0.00010992788230892447 -0.0002436229922891075 -9.867375121343967e-5 0.00012028813142437127 5.954671649719293e-5 -5.789242837135876e-6 0.00012841125788743947 6.194610540164816e-5 3.7111250262159646e-5; 0.00010642124229454476 3.0567274138966966e-5 -0.0001535709222069041 -6.17717438715603e-5 
-4.9895165516116806e-5 -7.735148270037401e-5 0.00016893981350696416 -0.0001349531000003258 -0.00010036708286386381 -0.00012189247232052348 -7.313438859889748e-5 -7.484783650225301e-6 4.981634387624759e-5 -6.800729775457229e-5 -3.478784006326821e-5 9.917634574342194e-5 -0.0001176192369300906 2.451837039663684e-5 -0.00013099035154918691 0.00014141210098861872 0.00018850470346044285 6.921508733759292e-5 -2.5117812172324036e-6 5.475358507485526e-6 9.732145676500773e-5 0.00017480459914188167 6.474203959917295e-5 -1.688825309863743e-5 0.00013044883357471717 1.7399817211463776e-5 0.00018953671073651838 0.00012530835496838186; 0.00018009124811782336 0.00015172129984987502 -3.962030553680099e-5 5.747249239210588e-6 -2.9625679084125902e-5 -0.00013211208955340374 -4.777070592696432e-5 8.940785971048899e-5 5.711240861842082e-5 -7.110003446883516e-5 -5.343470423554497e-5 1.2610286068534148e-6 0.00011174845473419202 -0.00013409385668069638 -6.0852641248631166e-5 0.00022668767393611164 7.412805977081887e-6 5.2537653906477043e-5 5.960620307288491e-5 -4.388319087753047e-5 0.00014285886333368288 4.440277313390914e-5 -8.799304517882137e-6 -9.189504874826064e-5 -2.7788274231853123e-5 0.0001470231267069206 -8.709153429094957e-5 0.0001423421102720089 0.00021419160963545212 7.609891814916473e-5 -9.459103842688075e-5 4.612292229125021e-5; 0.00010418965401669708 -3.4903327238573495e-6 5.864963973472937e-6 -3.5749045898077453e-6 5.934379337594133e-6 -2.7942312876511664e-5 -5.9070891361156986e-5 3.49785987830912e-5 5.214849087253261e-5 -7.1115122543073055e-6 0.00011290753180703154 -7.148317664787225e-5 4.0701834300597224e-5 -0.0001710849609614166 -0.00010433432172674615 -8.620865256914775e-5 1.463424276636999e-5 -4.4256226751423916e-5 -0.00012063409659703489 5.325483296957617e-5 -1.3923652581688867e-5 -4.005584552458983e-5 9.350870103418829e-5 2.8541215292824817e-5 -8.04700174873499e-5 8.077058745158721e-5 0.00016283287558082566 0.00013231737837970286 -6.438937739116152e-6 
9.554658762253303e-5 0.0001395718282631754 1.2262420646326479e-5; -5.628547767022964e-5 -3.52031794803511e-5 -5.994108612262928e-5 -0.000233991392791125 -0.00018551055215843703 -0.0001052319650383219 0.00021251499022279 0.00011919872379316797 -0.00019231774892737306 -0.00010995306472372768 9.322065200344607e-5 -3.828752339659293e-5 -0.00012966146566518398 7.33345430732242e-5 4.103456633551771e-5 -3.8520172141355285e-5 -0.00012904715656382278 -4.9747069327483577e-5 -6.432159510331825e-6 4.87574460081797e-5 -2.9662666905415757e-5 -4.675973939205281e-5 -3.172264875806908e-5 9.114632872888896e-6 -6.562390005761094e-5 3.1484642774241675e-5 1.5994802039719825e-5 9.289585325554838e-5 1.2220098894120372e-5 -0.00015057690431902955 7.42900436549984e-5 4.312366838954138e-5; -1.4805223884881164e-5 -6.0067571951767e-5 -8.417449299783137e-5 0.0001187426152562211 3.0680379829771842e-6 -4.361809061562219e-5 2.217710624495469e-5 -1.6116837117136494e-5 -1.3763266736771913e-5 -6.934908907465427e-5 -2.8206006221136806e-5 -3.878502485640814e-5 -0.00012155794121395868 -3.254406857491727e-5 -0.00010999695388205881 6.788140143429262e-5 -0.00017156604900561841 0.0001780042204579718 -4.575187461546768e-6 5.50974856985193e-5 0.00010848276559657998 -6.538259350531588e-5 -0.00030882978336491026 -0.00015075713037118486 0.00016214788610044263 -1.0482575647309406e-5 -6.756519161468914e-5 -0.0001066879866023217 -4.279313526535735e-5 -3.065608854120681e-5 3.46164144382307e-5 -8.983346361970023e-5; -5.2235139954466396e-5 7.470607328427342e-5 9.986773169800248e-5 3.5294788099752766e-6 4.0062255940539545e-5 8.600295063076046e-5 4.482428941221613e-6 -0.00013204977927528855 4.2603762487093245e-5 -4.934120421097816e-5 0.00021312728153276555 0.00017601554667772806 6.180644799456981e-5 -1.3214712233490053e-5 0.0001566724861200479 3.9182770925951696e-5 -8.909844815064659e-5 -1.1530235277349926e-5 7.072025277150933e-5 -0.00017358263335921987 7.735311759603504e-5 -2.149932250585274e-6 -7.890713801174141e-5 
-0.00011796328522763079 2.859959044287038e-5 -3.0286697218287906e-5 6.880078239334283e-5 0.00022424053553692414 0.00018899287191884973 -1.6278421534026667e-5 4.307828695851795e-6 -6.153212940120602e-5; 1.0240010592028677e-5 -3.82882247190619e-5 0.000111739260059476 -2.133572732920107e-5 -6.108525147784974e-5 0.0001247319302951662 -8.435578664422146e-5 7.693721246549187e-5 -8.493307567325161e-5 0.00022586992391960308 -8.908083717462298e-5 -2.7487243901753957e-5 -2.5491383244586586e-5 -5.2291048892146216e-5 -1.2286632486865852e-5 2.5155933794376634e-5 -9.683991525452743e-6 1.635827876678687e-5 -0.00011233051217050462 -7.459234646877672e-5 2.605565688507475e-5 -1.0368930166361649e-5 -0.00016087082648074872 -0.00019509080474155576 -8.017722601427491e-5 -1.0782748436683422e-5 9.604081398218169e-5 0.00016735229265985363 -0.00018875238336978705 -2.547991269738067e-5 3.5488978614790067e-6 8.435796891215924e-5; -0.00015799781345203297 5.683753848922129e-6 7.299753537727746e-5 -0.00014035451845638232 -1.5893550020058258e-5 -3.798799030272997e-5 0.00014241988442000223 0.00018419694874529017 0.00014328204174153728 0.00010672899129195365 -3.052597553841545e-5 -6.391091397171613e-5 -0.0001575324432030298 -0.00014245163861568796 6.502887749439331e-5 -0.00017187951112072693 -1.93198312023865e-5 8.314585915860014e-5 -2.28946074036626e-5 5.387624193014262e-6 1.3516478198931934e-5 3.228795225941801e-5 9.449487601733061e-5 0.00010829206343833583 4.8061184267050036e-5 -1.9899328484546943e-5 -0.00013089929512142913 -4.141729828342571e-5 1.3014287966422188e-5 -2.004790535801796e-5 -6.260432249922386e-6 6.050207134313175e-5; -2.879988910903168e-5 -8.24812586452303e-5 -3.544742772105288e-5 3.023542835395417e-5 -1.1140680684586641e-5 -4.501721790844499e-6 -8.219739443487593e-5 1.3125621122484646e-5 0.00010766348081774973 -0.00015623153645177905 2.8439165404010936e-5 -7.38819425841998e-5 -0.00010863999973848243 -6.204721917496896e-5 3.6238147955331175e-5 -3.9933521251484086e-5 
0.00011440488837793366 -0.00016217257412248914 1.6392151379457887e-5 3.112441576918971e-5 -0.00015150977465422787 -0.0002846034976337568 8.361559635845944e-6 0.00017535650935856777 4.136821466632063e-5 9.643934849551211e-5 2.4691705262697385e-5 -9.500841311277823e-5 0.00015576610989770568 6.984055020984923e-5 -0.00011170592820602792 -6.981140809697443e-5; 0.00018190478469862188 -0.00014810044929233198 -0.0001256241140033858 0.0001415243113335463 7.549759350169791e-5 3.945640667026972e-5 -7.50813833981493e-5 -0.00010510774707270659 4.693762479492125e-6 0.0001747227014895758 -3.2932965110175373e-5 -0.00015917972457441654 7.92422445030956e-6 -6.856152874732165e-5 -0.0001311064607541812 -8.8238701278697e-5 4.310524120612151e-5 0.00011512423400219902 6.244435819484103e-5 4.6639322976464914e-5 -0.00015516469109186725 -0.00020745586341560927 -0.00021340961946023445 5.731866098559981e-6 0.0001187020042600303 0.00019720232551137323 0.00013641935395227839 -3.993591788073849e-5 -5.748096562072024e-5 -8.19904435736549e-5 -4.110231271419925e-6 -6.738553754601063e-5; 8.985631745584167e-5 -0.00019555121082536902 1.1582673067354992e-6 0.00025681528665866574 -1.528187215460176e-5 -0.00012787090793442204 -0.00014725631456546587 0.00011075001554839094 1.2742753707997167e-5 -6.375116766154926e-5 0.0001533104822299981 -9.434860669731812e-5 5.1670360218694464e-5 5.741651304405272e-6 -0.0001986926491810404 -4.316204843068726e-5 -1.4295135886797175e-5 -0.00010946406667078086 0.00010347108206754637 2.618280802546839e-5 0.00014376461501508742 -8.929972071355083e-6 -0.0001577528292984164 9.954814038049236e-5 -0.00015475644443374356 3.661223115325728e-5 0.00016201245477698517 -6.960395524235657e-5 -6.353977198902642e-5 -0.00012349312430902918 0.00018393779634773603 6.69101299279915e-5; -9.784914278386727e-5 0.00010688646010860309 2.1593213879189737e-5 7.794820134333706e-5 5.57751934513577e-5 -0.0001642087284006054 -0.0001809848311120512 7.889288985208926e-5 -3.6791569308458456e-5 
-3.0131328693438502e-5 8.587556816676024e-5 4.104539924922099e-5 -8.982663553868082e-5 -0.00011088943416214218 0.00014230959799644847 -0.00015063604065178586 -2.5491588766068838e-5 -5.366098981428545e-5 0.00011551357032769869 0.00011048701138972659 -7.718694383472812e-5 -0.00016002665348312416 0.00010444277881217514 0.0002020342815205581 -2.394246099245426e-5 8.303049411663198e-5 7.111892079804189e-7 -6.893050000033021e-5 2.750952602693722e-5 -4.664985886736069e-5 -2.7034999455987473e-5 3.761803758718445e-5; 0.0001461797064806865 -0.00014200768871851753 -5.01490694474348e-5 9.731889351350894e-5 4.7218071831956156e-5 -3.5032074246920196e-5 1.165362302380956e-5 -0.00010518151337200923 1.1661371009176052e-5 9.519918150581168e-5 -0.00012426739065595136 6.532652971181001e-5 5.5668801924546946e-5 0.000143426324048413 -8.44548494384079e-5 -3.27505267519889e-5 -3.0891166697620794e-5 -7.31343252296955e-5 -2.5965261538300954e-5 6.0053024772421753e-5 0.00010815951010990721 0.0002866078421178794 -3.216437560659071e-5 3.55453257910981e-5 0.00010127115730974698 -2.249434810075202e-6 -6.778253098637001e-5 -1.8794257793707466e-5 -6.197532728001787e-5 -0.00011895346777201918 -7.61758865151957e-5 -3.674960769749546e-5; -0.00011059168504921936 9.264095965867453e-5 -0.00012611918250236725 -2.694545005120563e-5 -0.00017863895086243906 9.7977277940157e-5 -3.7936238774781174e-5 7.494011183829891e-5 -3.6336786840326095e-5 -2.4823751707078893e-5 -2.8929445440514125e-5 3.908756875992776e-5 2.962711946842581e-5 8.040304633417982e-5 0.00021074425272039993 1.3265558985079887e-5 -1.7832926138868346e-5 3.554513419025411e-5 -8.922115707442755e-5 1.0612364670162112e-5 -0.00013921437825679885 0.0001564034576886489 0.00011144759296534477 5.977791838908437e-5 0.00014982396928921274 4.577752438585118e-5 -0.0001892501783431329 7.665288681259284e-5 -5.78284194916325e-5 -8.115100685878827e-5 -7.058352243526143e-5 6.851766043204584e-5; -7.746405681648984e-5 -0.00018718848805696298 0.00016121129810524077 
4.567693663892214e-6 -7.739414213977286e-5 0.0001711097873299646 -4.776604255715424e-5 -4.552234095574546e-5 0.0001636821987591013 0.00020236025811352183 -1.5251497829400713e-5 7.682252980479338e-5 3.186179382789462e-5 7.845630243134372e-5 -0.00010853975182219562 -3.4772277289835026e-5 7.529948270127977e-5 2.1646780438068882e-5 -0.00012082170651815217 1.4741485570602932e-5 -7.164340714046059e-5 -4.38013441493564e-5 7.330433508793558e-5 2.4376459619427206e-5 -0.0001546777341031201 7.669686674086847e-5 -2.2429307631912664e-6 0.00012762892656207795 8.609172145549327e-5 0.00014439578880731487 8.611413973601727e-6 -7.058258874560101e-6; -3.1814129535130436e-5 0.00014757586418323232 7.306928696248137e-5 -0.00010164137834312679 -3.156083889866641e-5 -0.00010503667485995732 8.369232419258543e-7 -2.4048001570417947e-5 -6.107974090941755e-5 4.044104812105773e-5 8.045689945493234e-6 4.778319875964779e-5 -5.1462383772967774e-5 0.00010208169501179374 -5.2017208282859615e-5 -0.0002592172071230838 -7.448213487027743e-5 0.00012395532635178043 -0.00016794434973366255 0.0001672599634262512 -3.238387339611084e-5 9.544823355877633e-5 -1.837447539533032e-5 2.927028676690564e-5 0.00012364475937697502 -2.3715144699461976e-5 -1.613770414355734e-5 3.658434675212457e-5 -4.478055539170532e-5 -3.9881558542612856e-5 -9.483047811176898e-5 1.4365667674203449e-5; -0.0001461228340844728 -6.171431155914135e-5 -1.758393878795381e-5 -8.36894934395435e-5 -7.007712989385195e-5 0.00013104582179998295 3.438896314086097e-5 0.00011813181194205998 8.91672724178922e-5 -8.554448103914385e-6 6.71872083699256e-5 0.00015744781270815897 4.3217304123945685e-6 5.1788491728568054e-5 0.00016587235299133951 -9.583965871520992e-5 7.798583541493201e-5 1.929393587537595e-5 1.0474327600871315e-7 -4.72890489249586e-5 -0.00014452484461749719 -0.00010965569086644239 -1.1957287060302282e-5 0.00010883361832433494 -9.232807056783417e-5 0.00010395349530930678 -2.4128560798505513e-5 0.00025712103964955753 7.337346764010188e-5 
-0.00011676748832495557 -1.332689058423205e-5 -4.7603272068464774e-5; 2.0019086769148845e-5 -0.00016852716664276298 8.559270343551705e-5 1.4415856349056386e-5 -0.00010365041391184146 3.281967354281186e-5 -9.065429487224101e-5 0.00015115696140140616 -0.00010012214926961491 8.383839001920517e-5 0.00017897757486947226 2.941685550481028e-5 -0.00011427131841616312 0.00033218918968499063 8.480989770772244e-5 -0.000164721185974364 -4.255719552008134e-5 3.009876007139171e-5 -0.00022038832987151088 -0.00026990538328378955 -5.8945918972190776e-5 6.0022590951245724e-5 5.984573943746209e-8 -0.00010348724328638611 0.00017839186028152998 -8.352750894106518e-5 2.6006439548049394e-5 -0.00010286803746554673 -2.3423440616046503e-5 -2.257030537159169e-6 0.00011921587127302344 5.300909168012008e-5; 2.3354412028149284e-5 3.605179661899542e-5 8.677996955597163e-5 0.00021190252444723293 1.5697488911271733e-5 -4.6818439309539604e-5 -3.4919301847535935e-5 2.760233981901878e-5 0.00011835746553045792 3.2890680435980244e-5 2.2853547838925608e-5 -3.103643613508482e-5 6.463656753551673e-5 3.4061163737380647e-5 7.818004959047962e-5 -2.6062429619889674e-5 8.002037937770843e-5 8.230561214516188e-5 0.00015288385619422224 4.193338607803999e-5 -0.00017659596634545435 0.00019751485668603993 -0.00013345867629611021 -0.00015531729733554042 0.00010011597767955918 8.3575113953728e-5 9.32218186652215e-5 -1.163814833049878e-5 0.00013034199363102614 6.565026850214619e-5 4.756368215095761e-5 0.00013483150686153766; -3.917527375174727e-5 7.701993116498265e-6 -0.00011857717710095113 4.4358073514165353e-7 3.311871022437806e-5 0.00016941365810330323 1.4275733509424756e-5 2.2933393655557237e-5 -4.5712048642008524e-5 3.7353219330404275e-5 9.60747556643119e-5 0.00018051845744469265 -1.3075426987918121e-5 0.00014160570480148513 7.961262507583323e-5 -6.401368683450522e-5 -0.00011726206231410292 -1.53740047711405e-5 7.408111467942531e-5 -2.4570256403905383e-6 9.14338396673797e-6 -0.0001401382041980076 
-2.5555994408342493e-5 -8.080018688938407e-5 -0.0001407226673212335 -2.1423028566780334e-5 0.0001275755484936324 -2.4588368443210765e-5 -0.0001465488176212192 0.00017502628815111278 0.0002016813805208998 -1.7484534520330698e-5; -0.00013108665934987636 -3.368737288974102e-5 -3.357053920035821e-5 5.7294515316698104e-5 -2.1677255246233308e-5 0.0001259670878412148 -8.1396321731062e-5 -0.00023048223574869862 -0.00011592657592837292 3.617371275913578e-5 0.00010738825021553294 -5.914385624102062e-7 8.57097238485759e-6 -2.2859936690567954e-5 -8.106297356338093e-6 1.057147237451882e-5 -6.757116081192825e-6 -0.0002227865572435818 0.0001538118283847396 6.2813543807805e-5 0.00012839728678818303 -0.0001395982491899288 2.6727256953106057e-5 5.130626759500737e-5 -5.3189453139115195e-5 2.509032172836812e-6 5.977393611692277e-5 -8.571757116579029e-5 3.144746169573421e-5 2.2378517131545903e-5 3.642531173545641e-5 -1.3726075123031655e-5; 0.00014588312556995016 0.00013772542375632056 -0.00012260617144376894 8.152978340191632e-5 -2.3680576432096988e-5 0.00011664543204504023 -2.1998254978406256e-5 7.200751211755827e-5 -9.021046493173662e-5 5.770917492597503e-5 -0.0001211174013765456 9.133615900319666e-5 2.51195896093777e-5 5.950472847005047e-5 8.153663735399937e-5 0.00018217763941763952 -9.383672217111081e-5 -5.548517164724842e-6 -5.403094031822815e-5 0.0001320919262620636 -0.00015828511017346031 0.0001933901229317223 -8.14943770269792e-5 0.0001491653100497143 0.0001962301474672841 0.00017246467256015519 -1.3815664842261526e-5 -6.874577151349886e-5 0.00012616995160029927 -0.00011581915818133238 -1.855977565313427e-5 -7.828702577129179e-5; 5.89100170897998e-5 -0.0001282274963316823 -3.824006141509845e-6 0.00013244849904614728 4.3895063556847e-5 -4.5903986772170815e-5 -5.141296352784112e-5 -0.00010760669758135711 -1.217112114070708e-5 0.0002472890853666457 -4.0658483355653804e-5 0.00017909672968261775 -8.520144744834184e-6 -0.00012036035260942192 -6.756574950238918e-5 8.51627271756324e-5 
-1.5289687523615783e-5 -0.00010638086607039488 0.00010752010783353807 -1.464371709269574e-5 9.632903018675429e-5 -4.0499354524659554e-5 4.904673986195267e-5 -2.987466106540455e-5 5.1833049640411434e-5 2.2194523610523763e-5 -7.754682812199221e-5 7.291165891976742e-5 0.00013399721572815608 9.688335629351649e-5 -5.5798139526162166e-5 -7.380018102851655e-6; -7.935193997999178e-6 -0.00011447240837647344 6.11367334435197e-6 5.792450968378192e-5 -1.0977622031950296e-5 -3.716549696897396e-5 3.61569072361296e-5 4.781675300413511e-5 8.500805409361378e-5 -9.028860228632665e-5 -4.849930707629046e-5 2.619261788873027e-5 -7.002261560906805e-5 -1.777563199490029e-5 -3.0991752345890396e-5 0.0001974361002162725 -3.783504877852529e-5 -2.9610084374743743e-5 1.8810196643742127e-5 5.075460278980341e-5 0.00014773346622909213 5.6212866120898204e-5 -0.00014504294819695147 -8.71522498250304e-5 8.926769071926104e-5 -0.00012721552781791556 -2.4770475710702735e-5 -1.9216284335434768e-5 4.3706040678936956e-5 -0.00014794487657986277 9.887072513314877e-5 -0.00011363656819363149; -7.328599265336218e-5 0.00011838787661374003 7.86828886016566e-5 -6.538629629674547e-5 3.261032809361774e-5 -5.072912342755015e-5 -3.84675304990012e-5 -3.4808826126641545e-5 9.431426490166325e-5 -7.293198821160114e-5 -7.856870895509542e-5 6.249666232589022e-5 -9.649779628987442e-5 0.00013777292672286778 -0.00011004509822257675 5.1350470061629514e-5 -7.745025602660013e-5 0.00012356960262379548 -0.00023877014707567149 -0.00015943581190683975 7.966881723816571e-5 -2.7063398261343888e-5 5.326150032899402e-5 9.787392530099502e-5 -1.9375164879852454e-5 0.00011191950663899032 -0.00012611815425359526 2.2216196998237525e-5 8.092084430719431e-6 7.34544509437022e-5 0.00025836013712662084 0.00014288240974976498; -0.0001351716260392346 -6.602451919709483e-5 -7.990809406285663e-5 0.00012663246291092265 1.0556903492816463e-8 0.00011151890923835797 3.2980236414435384e-5 2.235077902193646e-5 8.459523216922301e-5 8.750740038951434e-6 
-1.547160588918704e-5 -0.0001119408176014259 1.7938763834965046e-5 -0.00014748614064549536 0.00010220275128128463 5.852223096050768e-5 -1.984918942573962e-5 -4.393508659225346e-5 0.00015320555523837945 -1.7366316183622913e-5 1.4679312175277315e-5 0.00011391546415732002 7.120755019465944e-5 1.8744402036386428e-6 -0.00011688900719994234 -2.2472132127885036e-5 -5.969879437136942e-5 -1.573794777462252e-5 0.00010888023227787265 -0.0002786926693948026 6.250846512987326e-6 -4.0113880982541936e-5], bias = [1.7961731209854162e-9, -1.1688129315307571e-9, -9.773901680813207e-10, 2.2317574950537164e-11, -5.444722067917934e-9, -1.9288704212601844e-9, 2.1912867173200623e-9, 3.424580654262725e-9, 1.353405715126613e-9, -3.246144254867433e-9, -3.857906690657621e-9, 3.4381096772187434e-9, -1.6009703539052454e-9, -9.583286851564076e-13, -2.465598342535315e-9, -1.7517557176228536e-9, -3.4754354741580563e-10, -2.166950951360591e-10, 5.957375827257967e-10, 5.642078065593702e-10, 2.2126146862426114e-9, -9.143618176297331e-10, 1.5032497171203748e-9, -7.087872893509728e-10, 5.567548976899038e-9, 1.4251678751166307e-9, -1.1443276015846582e-9, 4.060766949321969e-9, 1.8124253285020975e-9, -8.222463846834475e-10, 1.1267952915190247e-9, -5.589531876827552e-10]), layer_4 = (weight = [-0.0007050179490613795 -0.0004920573102096551 -0.0005536344310431565 -0.0006963821146971237 -0.0006593472889796238 -0.0006523439937773694 -0.0007408739416703592 -0.000755089413970887 -0.0006945990225966604 -0.0009152151640364027 -0.0006443450217995234 -0.0006971588265510892 -0.0007945894088549281 -0.0007317844238021631 -0.0008512115555582115 -0.0004974773857445845 -0.0006176108584618909 -0.0007019271734870736 -0.0007112881600718082 -0.0006877473375983264 -0.0007362892563509111 -0.000657459808557951 -0.0007390653613356624 -0.0005830805529843867 -0.0006775036209883346 -0.0008544847020083854 -0.0007603814490872107 -0.0004756124223530813 -0.0007177428187299262 -0.0006988079278728038 -0.0007884145426630212 
-0.0007008952936738197; 0.00020726024914500163 0.00018378483111754996 0.00014467837320017408 7.08587690374954e-5 0.00025217062008363225 0.0003449413308827007 0.00010630030363585429 9.953493892481873e-5 0.0003010931781629568 0.0004028239492081044 0.0002928969482291252 0.0001863185924666615 0.0002732050927064584 0.00011776133077577541 0.00011050089263328886 0.00015698714087173296 0.00022791349292202625 0.00019629299288992903 0.000418058310709979 0.0003617885577443187 0.00047557731304555853 0.00018800524616039061 0.0002839894831757226 0.0002735159476191791 0.00021503511590495401 0.00032489191650729394 0.00014718636375011607 0.00027461047580199313 0.0003138825841641192 0.0001243358086855583 0.00022999763360557278 0.00015335620821860143], bias = [-0.0007017932612105686, 0.0002357911075617672]))
Visualizing the Results
Let us now plot the loss over time
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    # `losses` holds one entry per callback call, so the x-axis is the call index.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)
    fig
end
Finally, let us visualize the results.
# Re-integrate with the optimized parameters `res.u` (same fixed-step RK4 setup as
# before) and build the trained waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4();
    u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Observed data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    # Prediction before training (`waveform_nn` was computed from the initial `params`).
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    # Prediction with the optimized parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    # One legend entry per series, each showing its line and markers together.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    fig
end
Appendix
# Appendix: print the Julia/system configuration that produced this page.
using InteractiveUtils
InteractiveUtils.versioninfo()
# GPU toolchain details are printed only if MLDataDevices is loaded. Each backend
# (CUDA / AMDGPU) is reported only if its package is loaded and
# `MLDataDevices.functional` reports its device as usable.
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
# Blank line separates the GPU report from the preceding versioninfo output.
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.