Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model, but from a scientific machine learning perspective these are not especially relevant.
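We need a very crude two-body path. We treat the one-body (Newtonian) motion as the relative position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and split it between the two bodies according to their masses, keeping the center of mass at the origin: $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$ with $M = m_1 + m_2$.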
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of two bodies with
masses `m₁` and `m₂`, placing the center of mass at the origin (`m₁ r₁ + m₂ r₂ = 0`).

Returns the tuple `(r₁, r₂)`; by construction `r₁ - r₂ == path` up to rounding.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
Next we define a function to perform the change of variables:
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution into Cartesian orbit coordinates.
#
# `soln` must have 2 or 4 rows and one column per time sample:
#   - 2 rows `(χ, ϕ)`: the orbital parameters come from `model_params = (p, M, e)`
#     (only `p` and `e` are used here);
#   - 4 rows `(χ, ϕ, p, e)`: `p` and `e` are read per sample from rows 3 and 4.
#
# The radius is `r = p / (1 + e cos χ)` and the result is the `2 × N` matrix whose rows
# are `x = r cos ϕ` and `y = r sin ϕ`.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` explicitly: `length(nothing)` would otherwise raise an
        # unrelated MethodError instead of this message.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
This function approximates the first time derivative using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled series `v` (spacing `dt`).

Interior points use the second-order central difference. The first and last points use
second-order one-sided three-point stencils, so the result has the same length as `v`.
Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    # one-sided (forward) stencil at the first sample
    first_pt = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # central stencil for samples 2:end-1
    centre = (v[3:end] .- v[1:(end - 2)]) / 2
    # one-sided (backward) stencil at the last sample
    last_pt = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(first_pt, centre, last_pt) / dt
end
This function approximates the second time derivative using the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled series `v` (spacing `dt`).

Interior points use the three-point central stencil. The first and last points use
second-order one-sided four-point stencils, so the result has the same length as `v`.
Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    # forward stencil at the first sample
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # central stencil for samples 2:end-1
    body = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    # backward stencil at the last sample
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, body, tail) / (dt^2)
end
Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return the time series of one component of the trace-free quadrupole tensor
`mass * (xᵢ xⱼ - δᵢⱼ (x² + y²) / 3)` of a planar orbit.

`orbit` is a matrix whose first two rows are `x` and `y`. `component` is an index pair.
`(1, 1)` selects the xx component and `(2, 2)` the yy component. Any other pair falls
back to the off-diagonal xy component.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i, j = component[1], component[2]
    r² = x .^ 2 .+ y .^ 2   # trace of the planar second-moment tensor
    tensor = if i == 1 && j == 1
        x .^ 2 .- r² ./ 3
    elseif i == 2 && j == 2
        y .^ 2 .- r² ./ 3
    else
        x .* y
    end
    return mass .* tensor
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the requested component
of the trace-free quadrupole tensor of `orbit`. See `orbit2tensor` for `component`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the three independent components `(xx, xy, yy)` computed by
`h_22_quadrupole_components` for `orbit`, sampled with spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    hyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    hxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    # note the return order: xx, xy, yy
    return (hxx, hxy, hyy)
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` of a single body with unit mass.

`h₊ ∝ h_xx - h_yy` and `hₓ ∝ -2 h_xy`, both scaled by `√(π/5)` in the precision `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h_xx, h_xy, h_yy = h_22_quadrupole(dt, orbit)
    amplitude = sqrt(T(π) / 5)
    plus = amplitude * (h_xx - h_yy)
    cross = -amplitude * (T(2) * h_xy)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the per-body `(xx, xy, yy)` components from `h_22_quadrupole` for two bodies with the
given orbits and masses.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # component-wise sum, keeping the (xx, xy, yy) order
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` from the orbits of two bodies with masses
`mass1` and `mass2`, which must sum to one.

The unit-sum check uses a tolerance that scales with the floating-point precision of the
masses, so e.g. `Float32` masses that sum to one up to rounding are accepted. The
tolerance is never tighter than `1e-12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2
    total_mass = mass1 + mass2
    # A fixed 1e-12 tolerance is far below Float32 resolution (eps ≈ 1.2e-7), so masses
    # computed in Float32 could spuriously fail the check.
    tol = max(1.0e-12, 16 * eps(float(typeof(total_mass))))
    @assert abs(total_mass - 1.0) < tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode gravitational-wave strain `(h₊, hₓ)` for an orbit solution.

`soln` and `model_params` are interpreted as in `soln2orbit`. `dt` is the sample spacing
used for the time derivatives. `mass_ratio` must lie in `[0, 1]`. A value of zero selects
the test-particle (one-body) limit. Otherwise the orbit is split between two bodies with
masses `q/(1+q)` and `1/(1+q)`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # test-particle limit: the whole orbit belongs to a single body
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)   # heavier body
    m₁ = mass_ratio * m₂          # lighter body; m₁ + m₂ ≈ 1
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
Simulating the True Model
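`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^{2}\,\sqrt{(p - 2)^2 - 4e^2}},$$

$$\dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}},$$

where $\chi$ is the (relativistic) anomaly and $\phi$ the azimuthal angle of the orbit, so that $r = p/(1 + e\cos\chi)$, $x = r\cos\phi$, $y = r\sin\phi$. Here $p$ is the semi-latus rectum, $M$ the mass, and $e$ the eccentricity.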
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the Schwarzschild orbit ODE for the state `u = (χ, ϕ)`, with semi-latus
rectum `p`, mass `M`, and eccentricity `e`. Returns `[dχ/dt, dϕ/dt]`; `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)
    # factor shared by both rates
    common = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root_term = sqrt((p - 2)^2 - 4 * e^2)
    dχ = common * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root_term)
    dϕ = common / (M * (p^(3 / 2)) * root_term)
    return [dχ, dϕ]
end
# These globals are read inside `loss`, so they are marked `const` (as recommended later
# in this tutorial) to keep the loss type-stable.
const mass_ratio = 0.0 # test particle
const u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
const datasize = 250
# time span for the GW waveform; note it is Float32, so `tsteps` and `dt_data` are Float32
const tspan = (0.0f0, 6.0f4)
const tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
const dt_data = tsteps[2] - tsteps[1] # sample spacing, used as the finite-difference step
const dt = 100.0 # fixed step of the RK4 integrator
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Ground-truth trajectory from the relativistic model, integrated with fixed-step RK4.
# `waveform` is read as a global inside `loss`, so these bindings are `const`.
const prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
const soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# keep the h₊ polarization: the first element of the returned (h₊, hₓ) tuple
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the ground-truth h₊ waveform as a line with the sampled points overlaid.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
# group the line and the markers under a single legend entry
axislegend(ax, [[l, s]], ["Waveform Data"])
# the `begin` block evaluates to the figure
fig
end
Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the orbital variable $\chi$) and has two outputs. We'll make a function `ODE_model`
that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl.
const nn = let
    # Shared initializers for every Dense layer: tiny truncated-normal weights and zero
    # biases. Because the last layer has no activation, its output starts close to zero.
    dense_init = (; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32)
    Chain(
        Base.Fix1(fast_activation, cos),   # feed cos of the raw input to the first layer
        Dense(1 => 32, cos; dense_init...),
        Dense(32 => 32, cos; dense_init...),
        Dense(32 => 2; dense_init...),
    )
end
# NOTE(review): `Random.default_rng()` is unseeded, so the initial weights differ between
# runs; pass a seeded RNG if reproducible results are needed.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00015090362; -8.463823f-6; 0.00020540055; -0.00015700595; -4.4267104f-5; -4.7773992f-5; 0.00023573925; -2.7327524f-5; -0.000114716786; -3.9591305f-6; -0.00014696848; -0.00019809845; 8.3716266f-5; -5.340119f-5; 0.00017846288; -0.00011863013; 7.2552204f-5; 3.5595345f-5; -5.0283947f-5; -4.9698083f-5; -1.3690377f-5; -0.00024830314; 0.00012856653; 6.1756524f-5; 8.470422f-6; 0.00016430899; -0.00012057677; -2.974124f-5; -0.00014401098; -8.6338165f-5; 0.00016550989; -8.55622f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-2.1165648f-5 5.0500952f-5 4.7947782f-5 7.256527f-5 -9.463453f-5 6.559931f-5 -1.3545024f-6 4.4081713f-5 4.7392732f-5 -2.7824513f-5 3.1382115f-5 -1.0304528f-5 5.6110053f-5 0.00010574072 3.4803285f-5 -1.8547906f-5 0.000130308 0.00031772285 8.095491f-5 3.2271586f-5 -6.830569f-5 2.0922048f-6 8.473368f-5 7.6166034f-5 7.9883175f-5 2.4299761f-5 -0.00014244307 -8.437712f-5 0.000121472876 -0.00028370868 -7.024228f-5 -0.00016436468; -1.9498599f-5 0.0002384625 0.00015677931 -2.0874806f-5 8.188062f-5 -7.3008654f-5 -4.4211687f-5 -0.00021810514 -0.00019904773 0.0001718232 -0.00015571687 8.389914f-5 -3.771425f-6 4.827022f-5 -8.6851964f-5 4.1023835f-5 -9.827163f-5 -2.9761604f-5 -0.00011055285 -2.8316217f-5 -6.8101217f-6 6.973589f-5 -0.00014166758 3.295568f-5 0.0001123189 1.4763515f-5 -7.04633f-5 -3.643886f-5 -0.00013884337 -0.0001001675 2.8880702f-5 -0.00017509604; -0.00020911284 -7.536324f-5 -6.166835f-5 -0.00013322335 -3.7136855f-5 3.8413757f-5 -3.6381025f-5 0.00020342405 0.000117541014 3.8101756f-5 -0.00014812972 7.0241105f-5 -9.1356516f-5 -2.3749715f-5 7.457846f-5 8.435499f-6 4.247725f-5 5.5401975f-5 9.694191f-6 6.3040993f-6 -8.088973f-5 3.6150243f-6 -6.828356f-5 1.6875539f-6 -9.78848f-5 8.378348f-5 4.8198053f-5 -0.00010719506 0.0001749121 
-7.129386f-5 -0.00015710526 0.00018999833; -6.824945f-5 7.176751f-5 1.2675895f-5 -7.849781f-5 -0.00039939955 -8.810742f-5 -3.0614025f-5 0.00011582179 -9.029021f-6 4.981639f-5 8.993682f-5 -1.2523145f-6 2.5951993f-5 -0.000118193755 -5.4783595f-6 9.792813f-5 -0.00012660844 -0.00012810843 -0.00014531413 -0.00010627196 0.00012563007 -0.00032760715 9.188287f-5 -0.000107824286 -1.6014634f-5 -0.00010602593 4.5965484f-5 0.00015848411 -1.563245f-5 3.4812267f-5 9.781416f-5 3.4600605f-5; -1.2476194f-5 9.694459f-5 -0.0001190714 0.00011464835 3.1550768f-5 -8.2673585f-5 0.00017047249 1.9410103f-5 7.10154f-5 0.00013905355 1.5802454f-5 6.435535f-5 -2.52826f-6 8.4621555f-5 6.236057f-5 8.5720574f-5 -4.5355497f-5 -6.184095f-5 0.00011810701 -6.254415f-5 -5.5699482f-5 -3.106806f-5 0.0002059834 5.3599597f-6 4.5973182f-5 -2.343823f-5 -1.7022418f-5 -5.6002114f-5 0.000115259725 -0.0002209328 0.0001311833 -4.7988856f-6; 7.6252065f-5 7.847761f-5 -8.281996f-5 0.00014601443 0.00011454118 -8.0187485f-5 -3.6581716f-6 -2.6047546f-5 -5.461305f-5 4.9749146f-5 1.1885932f-5 0.00011439276 -0.00020970427 -0.00012041472 -0.00013609363 -2.4483385f-5 -0.0002253949 -5.973676f-5 4.1347193f-5 1.5695641f-5 4.4592805f-5 -4.061531f-5 1.0062132f-5 5.960993f-5 3.6369496f-5 8.42075f-5 -3.147741f-5 1.5265743f-5 -1.1776208f-6 -7.012644f-5 0.00026326958 5.4683256f-5; -0.00023150427 -4.4311102f-5 7.2963194f-5 -2.8592107f-5 -6.419032f-5 0.0001774565 1.8647734f-5 -6.7288083f-6 -6.4846005f-5 9.500336f-5 -8.6542896f-5 4.60881f-6 0.00010625601 -8.559041f-6 0.00020389163 -0.00011082882 4.125422f-5 0.00019953634 -1.2586349f-6 6.7798646f-5 -0.0002053087 -8.746093f-5 0.0001462021 -0.0002451178 -0.00014362352 7.610798f-5 5.62084f-5 -2.9942892f-5 -2.9659477f-6 2.5961268f-5 -8.221114f-5 7.732975f-5; 0.00016248746 -4.4453205f-5 -9.275598f-5 6.967f-5 0.00015391322 -4.1240786f-5 4.9220074f-5 0.00012305066 -7.626098f-5 -2.860815f-5 -0.000110672445 -8.3545834f-5 -4.7215403f-6 7.5050455f-5 1.4005798f-5 -9.519128f-5 8.232137f-5 
8.66922f-5 -8.851084f-6 -6.2590305f-5 0.00012959143 -0.00011558106 5.4535195f-7 -6.849799f-5 -4.373461f-5 -0.0001723428 6.968158f-5 -0.00011355082 -0.00010883078 -2.8716082f-5 -4.0192932f-5 3.1115524f-5; -4.7076413f-5 0.00012503553 -0.00010856605 3.4132598f-5 -0.00011791155 -6.292576f-5 -0.00021242686 -4.84029f-5 -5.207472f-5 -0.00015483044 -1.9088198f-5 1.2119367f-5 2.5169458f-5 4.8388367f-5 -6.8502864f-5 6.836119f-5 -0.00013620267 -0.00010441406 -7.537726f-5 9.1539776f-5 2.3982791f-5 -3.3601704f-5 4.4184395f-5 -0.000114855626 -7.843204f-5 -8.5113046f-5 4.5944216f-6 -8.975312f-5 3.6984176f-5 -3.0671148f-5 -2.0482472f-5 8.827059f-5; 0.00017643016 2.3205192f-5 -5.7305704f-5 -4.1985557f-5 -0.00010486742 0.00022033461 -0.00011402413 -2.8096816f-5 0.000122242 7.42707f-5 0.00014912353 -4.485481f-5 -7.0858434f-5 3.1245116f-5 -6.810411f-5 0.00012097205 0.00014908245 1.6642936f-5 5.278673f-5 7.511637f-5 0.0001307365 1.8963343f-5 -0.0001371775 4.8518898f-5 8.341599f-5 8.70275f-7 -9.713271f-5 -5.7679932f-5 6.191689f-5 0.00014351138 5.5135522f-5 -3.4488323f-5; -6.4956424f-5 6.845625f-5 2.4336203f-5 -0.00016200557 6.66831f-5 0.00011695535 0.00021468807 2.6414253f-5 -8.992955f-5 9.45933f-5 6.703183f-5 9.976361f-5 -0.00014852098 6.6697095f-5 -1.0355417f-5 -2.0534502f-5 -3.9912433f-5 -1.7285465f-5 0.00013311848 0.00012858218 0.00011164309 -0.00013240002 0.00015393294 6.452993f-5 5.0880863f-5 -3.2247847f-6 -0.00012428727 0.00010198965 -0.00013944262 8.697376f-5 -0.00010742211 3.0341993f-5; -7.8687284f-5 0.00015028851 0.00013419442 1.8898878f-5 -6.647389f-5 3.5678226f-5 -3.147098f-5 4.9226295f-5 -5.6841465f-5 9.868224f-5 -0.00020508574 7.496335f-5 0.000102686834 -6.4575004f-5 2.308184f-5 4.441304f-5 4.6647856f-5 -7.757841f-5 -0.00013522587 -1.8173934f-5 0.000105051324 -0.00023647283 0.00014404589 6.667556f-5 -2.317227f-5 2.883814f-6 -3.4245902f-5 -0.00022863159 -3.9888924f-5 -3.4146956f-6 0.00010923472 -0.00010998123; 3.4169832f-5 3.08009f-5 -1.6098798f-5 5.7419707f-5 -2.6387534f-5 
4.966144f-5 -9.170078f-5 3.4810066f-5 -0.00010099423 -0.00015306841 1.2887367f-5 -4.1907733f-5 0.00017781317 -2.8676117f-5 -5.1454037f-5 -3.3803124f-6 -0.0001782519 4.311269f-6 0.000112365815 1.4740496f-5 2.7109594f-5 -0.000102397426 8.091007f-5 0.00017430856 1.4523436f-6 -1.0458106f-5 -9.893838f-5 -7.0363916f-5 8.246475f-5 -0.00016958895 -9.585297f-5 -7.262079f-5; -0.00015964957 4.8577567f-5 -5.487405f-5 -0.0001445053 -4.211038f-5 -0.00012539276 -4.6623227f-5 9.07253f-5 4.2765027f-5 -0.00013549191 -4.6631487f-7 0.00017217438 1.7190947f-5 0.00014467581 8.7198096f-5 9.383959f-6 -6.0501325f-6 0.0001508446 6.241433f-5 2.4024892f-5 3.6057452f-5 -4.1752442f-5 8.102529f-6 -1.6674658f-5 3.343456f-5 -6.749138f-5 9.771836f-5 -1.5492298f-6 -8.478065f-5 -4.394246f-5 -9.06465f-5 0.00015966818; -7.518042f-5 0.00021259996 -8.823414f-5 4.3982596f-5 7.477998f-5 0.00010692436 7.3815618f-6 4.4274308f-5 -7.993137f-5 -1.8704854f-5 4.1245694f-6 1.7104796f-5 -0.00012053361 -9.1849004f-5 0.00010675494 -0.0001355293 -1.3328679f-5 6.6334087f-6 5.03943f-5 -0.00014101874 -7.29545f-6 9.914104f-5 -1.1474056f-5 -0.00018028563 0.00013944635 -4.1016767f-5 6.5955405f-6 -5.1904408f-5 -5.9763257f-5 3.4085202f-5 2.9388233f-5 3.9559163f-5; 8.096887f-5 0.00015928409 -1.1266973f-5 -7.9541576f-5 -0.00014184462 -5.0296727f-5 -9.8038065f-5 -1.6104643f-6 8.618698f-5 7.2509334f-5 1.558933f-5 -9.378192f-5 0.00018772883 0.00013281488 8.9208035f-5 -0.0001667347 9.291682f-5 -7.7662145f-5 -6.067225f-5 0.00021219777 -4.4655917f-5 -5.2664684f-5 -3.131379f-5 1.899719f-5 -1.9283283f-5 -7.62929f-5 5.5354252f-5 6.445478f-5 0.0001529196 -6.040016f-5 7.296937f-5 -7.1608124f-6; -3.915612f-5 -2.5472156f-5 -3.0540614f-5 3.83984f-5 -4.1990206f-5 -1.1827236f-5 4.9128084f-5 -1.1185228f-5 -3.8767314f-5 -2.1421924f-6 -0.00010326394 1.8307672f-6 9.062641f-5 0.00014975666 -3.889486f-5 4.721779f-6 -4.1758252f-5 5.7742f-5 5.0515362f-5 -2.4326582f-5 -0.00012401248 -1.7705523f-5 -8.95716f-5 -3.0053008f-5 5.277605f-5 -9.996833f-5 
2.4765603f-5 4.783959f-5 -5.025109f-5 8.4388965f-5 2.0156467f-7 -1.6331158f-5; 1.6624621f-5 0.00012956433 -2.0704f-5 -5.4022206f-5 6.746865f-5 -0.000105864754 0.00014648687 0.00010553307 -5.2909734f-5 0.00013290775 -0.00014006575 -7.293764f-5 1.4726692f-5 -0.00010275925 -4.59697f-5 4.8378963f-5 -9.491042f-5 -8.691381f-5 -4.357437f-5 0.0001256736 2.6804777f-5 -6.285282f-5 -0.0001506739 -0.0001336679 5.6545156f-5 5.1112962f-5 -7.029663f-5 5.715197f-5 0.000181966 -2.0401505f-5 5.777695f-5 -7.2891824f-5; 1.8041377f-5 -8.147996f-6 3.6691883f-5 -0.0001538761 4.0217845f-5 -0.000109499786 -2.4820192f-6 -5.4667147f-5 3.245275f-5 1.28862075f-5 3.628228f-5 -0.00010110825 0.00014104182 0.00012282091 4.1226293f-5 3.2769283f-5 6.913994f-6 -5.9135677f-5 0.00011294312 -7.6351374f-5 1.4028988f-5 5.4524726f-6 -2.8674403f-5 0.00012895289 5.126447f-5 -0.000116158895 0.00014292767 0.00011595494 -5.8143636f-5 -7.700508f-5 0.000172921 -0.00022471763; 5.5989083f-5 2.4201165f-6 2.9710616f-5 0.00016417308 1.6055588f-5 -4.0894698f-5 -0.0001660548 0.00017448509 -0.000117404954 -2.6518697f-5 1.98422f-5 3.9237348f-5 -3.6965954f-5 0.00015486458 4.7404523f-5 7.68102f-5 8.1417944f-5 3.186909f-5 -6.458914f-5 6.8920695f-6 -0.00014623796 0.00014359625 -4.074608f-5 -4.4789318f-5 -5.7927014f-6 3.872209f-5 -6.144943f-5 -4.6667697f-5 -4.271697f-5 1.05700165f-5 -0.00012458996 4.5511646f-5; 0.00013365946 -0.00019328561 -6.316584f-5 -6.689396f-5 -0.0001795808 5.155989f-5 1.9922052f-5 -4.8122856f-5 7.693376f-6 2.6320662f-5 -2.5604611f-5 -8.283266f-5 -6.541855f-5 -7.9630016f-5 3.5686968f-5 -8.575746f-6 -8.546905f-5 2.2313256f-5 -0.00012303692 -2.8425706f-5 -0.0002038851 0.00011940929 4.3230015f-5 -5.0992396f-5 -9.2215574f-5 -0.0001006019 -4.001055f-5 3.4672074f-5 1.1325736f-5 -1.0074009f-5 1.15132806f-7 1.9010058f-5; 8.009491f-5 8.859885f-5 9.733008f-5 -5.5736222f-5 0.00012154643 1.7487382f-5 -2.8436967f-5 -1.0671288f-5 0.00012157564 -8.417737f-5 -8.699047f-6 -7.055391f-5 0.000105306935 0.00012774575 
8.736735f-5 -1.9672781f-5 -0.00015637049 -9.217274f-5 0.000120779645 -0.00010406518 -3.668976f-5 -0.00010721674 6.959834f-5 2.1090449f-5 -9.943622f-5 -2.7680344f-5 0.00010118385 0.00020617584 4.4101176f-5 -0.00016865844 1.3938555f-5 -8.569023f-5; -1.976274f-5 0.00016394234 -8.680111f-5 -0.00014127154 -7.284679f-6 -2.3877568f-5 4.331517f-5 0.00013790275 6.636905f-5 7.295048f-5 -2.7635326f-5 -9.839307f-5 -6.41755f-5 -4.6028577f-5 -4.437729f-5 -1.0973883f-5 0.00018886518 -6.7572005f-6 -0.00012948559 -0.00010354269 0.00010292204 -5.8536676f-5 -8.075831f-5 5.822072f-5 -0.00017127384 -4.807289f-5 -1.1527733f-5 0.00019305263 6.37052f-5 -1.6944424f-6 -0.00020566987 -4.00872f-5; -0.00016451524 -6.207382f-5 3.3350761f-6 0.000119850934 0.00012435162 -5.3857697f-5 5.9947484f-5 4.5428405f-5 -0.00011291061 9.706844f-6 -3.5531823f-5 6.433628f-5 -3.139331f-5 -2.5661922f-5 -4.819812f-5 0.00018470979 -4.0358005f-5 -0.00023431667 6.7169814f-5 4.983102f-5 -0.00018061587 -6.8861475f-5 -3.0310955f-5 -0.00019008631 4.06062f-5 8.149342f-5 0.00017372139 -0.000111034286 2.342074f-6 5.622031f-5 1.9284242f-5 4.1040574f-5; -3.547374f-6 0.00014384532 9.07116f-6 0.0002145477 -1.0891746f-8 -0.00014184573 -2.3223272f-5 7.868545f-5 -8.042678f-5 -1.8246714f-5 4.995605f-5 -3.9673516f-5 1.6247643f-5 8.331116f-6 5.4973036f-5 -3.4404104f-5 1.427133f-7 -6.082243f-5 -0.00017639065 2.1534628f-5 -5.736868f-5 0.00015263002 -3.4185992f-5 -8.132145f-5 -4.1376832f-5 -0.00011322391 -0.00016457957 6.054473f-6 3.44907f-5 6.1890314f-5 -8.0112295f-5 -1.7204533f-5; 0.0001125962 -5.576989f-5 7.8566336f-5 6.935268f-5 -7.722156f-5 7.5860306f-5 0.0001043249 -0.00020738013 -7.051083f-5 -6.424485f-5 5.5459448f-5 -1.52584425f-5 0.00017225547 4.0163308f-5 -3.149853f-5 0.0001864849 9.878691f-5 8.1566985f-5 0.00026410085 -1.3374854f-5 -5.872415f-5 0.0001421029 0.00011655776 0.00022888214 -0.00012241608 7.795073f-5 5.8515434f-5 -8.452494f-5 3.2024513f-5 0.00015500371 5.7134013f-5 -4.680298f-5; -0.00014854681 7.591369f-5 
4.9790226f-5 -3.0244106f-5 -4.231071f-5 2.5181484f-5 0.00011698839 3.192006f-5 2.3411398f-5 -9.102787f-6 5.5223518f-5 -0.00015584743 9.667681f-5 8.616424f-5 5.5597156f-5 8.6109496f-5 0.00010408159 -7.6488686f-5 0.000105362604 -3.8339378f-5 -2.8873114f-6 1.8399864f-5 -8.279151f-5 -9.701829f-5 7.764604f-5 7.879665f-5 1.2832609f-5 -2.790893f-5 0.00013341618 5.2920703f-5 -7.7523495f-5 -2.5505418f-5; -0.00021366171 3.9475276f-6 -3.2217133f-5 0.00011181767 5.406752f-5 1.4846141f-5 -1.450177f-6 7.1318727f-6 -6.887335f-5 -4.904597f-5 -2.2252638f-5 -1.3013625f-5 -0.00019908288 -9.5764684f-5 1.8358915f-5 2.6269518f-5 -7.491253f-5 -1.1413631f-5 -5.3147673f-5 -3.3932127f-5 -0.000121958685 -0.00018620773 -0.0002268132 0.0001997613 -5.1649928f-5 7.5547854f-5 -0.00010477215 -0.00022782228 -4.573001f-5 -1.2988973f-5 7.514176f-5 0.00018934514; -0.00010067848 -3.9411284f-6 -0.0001059268 -2.7050935f-5 -2.4688685f-5 -8.358114f-5 0.00012999828 5.410952f-6 6.908646f-5 0.00011320491 0.00014282294 -7.956316f-5 -0.00016909902 1.0778803f-5 -2.0057008f-5 -3.8435002f-5 -9.504111f-5 -9.47948f-6 -4.707471f-5 0.00012479124 -8.3231906f-5 7.494093f-5 0.00021965355 -2.6914086f-5 -0.00012417957 -6.325471f-7 2.143948f-5 0.00014800661 2.8760205f-5 9.690896f-5 1.6894599f-5 -0.00012517214; -7.3907926f-5 4.018103f-6 -0.0002067274 -0.00016475574 -0.000116328614 -0.000112710506 0.0001338838 1.729044f-5 -6.8455054f-5 9.858215f-5 -3.1216863f-5 6.525146f-5 -0.00010678202 2.050075f-5 -2.6220294f-5 -8.556174f-5 -9.383551f-5 -2.0393572f-5 8.743258f-5 7.94915f-6 -0.0001002634 -0.00021169936 6.683955f-5 3.662552f-5 -2.3749455f-5 -8.982027f-5 -0.00019141244 4.7807465f-5 4.5858174f-5 5.4177588f-5 -0.00011780259 -9.1852584f-5; 3.8876733f-5 -3.0937495f-5 -9.406694f-5 6.940811f-5 -6.130588f-6 3.8223192f-5 6.9025664f-5 -0.000109138295 -2.0570424f-5 -1.9288991f-5 4.5164645f-5 -5.2538777f-5 9.736465f-5 5.242314f-5 1.3217264f-5 4.813359f-5 4.013479f-5 -1.3558623f-5 -5.024822f-5 5.6036755f-5 -3.6083096f-5 8.4877036f-5 
-8.461897f-7 -4.2224885f-5 2.2246018f-5 -6.247578f-5 -0.000119525685 7.498407f-5 -4.2561525f-5 -2.2542248f-5 -0.00012716188 -7.779595f-5; -2.6750185f-5 6.291541f-5 4.982271f-5 -0.00015119324 -4.0083753f-5 0.00015823684 -5.2933694f-5 2.9595232f-5 -8.7943f-5 -0.00016993792 3.4041015f-5 5.2228515f-5 3.1199637f-5 4.748912f-5 -3.0561056f-5 -0.00011368892 4.229177f-5 -0.00025960477 -0.00012873634 0.00014615091 -1.7064667f-5 -0.00012042332 -7.097036f-5 0.00014953867 -1.0653301f-5 -4.8177175f-5 -7.6921606f-5 -2.9108469f-5 -0.00014570197 -2.0344289f-5 0.00013480739 8.860958f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[0.00010827614 -2.5591968f-5 -1.1048545f-5 9.019165f-6 5.0362658f-5 1.6079077f-5 3.6900917f-6 -2.9768978f-5 5.5402325f-6 4.111949f-6 -4.537496f-7 -0.00017634103 -0.0001226249 0.0001538932 -0.0002463556 4.3353208f-5 5.6761935f-5 0.0001345399 -0.00010856281 -0.0001288653 -6.623624f-5 -9.63302f-5 -9.353219f-5 1.0995333f-5 7.103173f-5 -3.9238057f-5 0.00012379397 5.5275297f-5 2.6383666f-5 0.00013332238 -7.480599f-5 3.7480862f-5; -0.00015978346 -0.00015526234 -4.1720483f-5 -6.143811f-5 -0.00010680786 -4.364673f-6 2.8432687f-5 2.339937f-6 0.00022203944 5.974749f-5 0.0001632817 6.966566f-5 -6.112847f-5 -3.4048346f-5 -9.7721204f-5 -8.381741f-5 -0.00014006408 -5.3600743f-5 -7.5140444f-5 3.0462965f-5 -0.00015062978 -0.00014137666 -3.3266897f-6 1.5325408f-5 -1.8683826f-5 -0.00012379457 0.00014326724 0.00023545808 -7.0582704f-5 1.2926778f-5 -0.00012915785 -0.00010587142], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most deep learning frameworks, Lux defaults to using `Float32`.
However, in this case we need `Float64`.
# Promote the Float32 parameters to Float64 and flatten them into a ComponentArray, so the
# optimizer sees a single parameter vector.
const params = ps |> f64 |> ComponentArray
# Bundle the network with its state; the parameters are passed at call time.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
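Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected by the neural network. It uses

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $\mathcal{N}_1$ and $\mathcal{N}_2$ are the two outputs of the neural network, and $p$, $M$, $e$ are fixed to the values in `ode_model_params`.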
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = (χ, ϕ)`. It returns the Newtonian rate
`(1 + e cos χ)^2 / (M p^(3/2))` multiplied by `1 + N₁(χ)` and `1 + N₂(χ)`, where
`N = nn_model` is evaluated with parameters `nn_params`.

`(p, M, e)` come from the global `ode_model_params`; `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # `nn_model` is a StatefulLuxLayer that already holds the (here empty) Lux state `st`,
    # so the state does not need to be threaded through this function.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# `prob_nn` is read as a global inside `loss` (which re-solves it with `p=θ`), so it and
# the related bindings are `const`.
const prob_nn = ODEProblem(ODE_model, u0, tspan, params)
const soln_nn = Array(
    solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false)
)
# h₊ polarization of the untrained model, for comparison with `waveform`
const waveform_nn = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
# Overlay the ground-truth waveform and the untrained neural-ODE waveform on one axis.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
# ground truth: line plus sampled points
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
# untrained neural ODE: line plus sampled points
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
# one legend entry per (line, markers) pair, placed at the left-bottom corner
axislegend(
ax,
[[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"];
position=:lb,
)
# the `begin` block evaluates to the figure
fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the h₊ waveform produced by the neural ODE with parameters `θ`
and the ground-truth `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred = first(compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params))
    return mseloss(h₊_pred, waveform)
end
Warm up the loss function (the first call also compiles it).
# Evaluate the loss once at the untrained parameters; as the first call, this also
# triggers compilation before training starts.
loss(params)
0.0006947190204148482
Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(state, l)

Record the current loss `l` in `losses` and print a progress line with `state.iter`.
Always returns `false`, which lets the optimization continue.
"""
function callback(state, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, l)
    return false
end
Training the Neural Network
Training uses the BFGS optimizer. This seems to work well because the Newtonian model already provides a very good initial guess.
# Gradients of `loss` are computed with Zygote (reverse mode).
adtype = Optimization.AutoZygote()
# The objective ignores Optimization's second (hyperparameter) argument.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00015090362285239398; -8.463823178319395e-6; 0.0002054005453826323; -0.0001570059539516528; -4.426710438560331e-5; -4.777399226434966e-5; 0.0002357392513656482; -2.7327523639562606e-5; -0.00011471678590150692; -3.9591304812333405e-6; -0.0001469684793844431; -0.00019809845252899072; 8.371626608993363e-5; -5.3401188779306205e-5; 0.00017846288392311694; -0.00011863013060053942; 7.25522040737047e-5; 3.55953452526042e-5; -5.0283946620610274e-5; -4.969808287567535e-5; -1.3690377272711159e-5; -0.00024830314214285723; 0.00012856653484034854; 6.175652379164117e-5; 8.470421562371793e-6; 0.0001643089926799496; -0.00012057676940445217; -2.9741240723514597e-5; -0.0001440109772373567; -8.633816469225788e-5; 0.00016550988948428694; -8.556219836472145e-5;;], bias = [-2.793585799831911e-16, 2.2677545547519387e-18, 2.738257231758117e-16, -1.0416623590858451e-16, -5.824861459897287e-17, -6.936460262807558e-17, 2.717912450947749e-16, 2.4082085516947887e-17, -2.1336659467219778e-16, -1.5299135332283108e-18, -2.6042937926410695e-16, 5.897621861171444e-17, 1.34777125885288e-16, -9.662202774808112e-17, 9.894971218212377e-17, -1.2939010698814795e-16, 1.932312185984022e-16, 6.36810986157672e-17, -1.815048754421389e-16, -2.0745864704734232e-17, -1.4253311128243528e-17, -6.375576562833117e-16, 1.7724177453097424e-16, 7.078956695164118e-17, 7.831704912543081e-18, 2.4727144671401445e-16, -1.4671462454999861e-16, -1.007687603688814e-17, -1.5081623426596322e-16, -2.425676696888067e-17, 1.1383416949946175e-16, 9.082144830994143e-17]), layer_3 = (weight = [-2.116352594965056e-5 5.050307402929077e-5 4.7949904122656684e-5 7.256739366641221e-5 -9.463240973342556e-5 6.560143550235408e-5 -1.3523804149415483e-6 4.408383495832532e-5 4.739485405809991e-5 -2.7822391408665352e-5 3.138423665734171e-5 -1.0302405704708129e-5 5.61121752376785e-5 0.00010574284272810038 3.480540647943891e-5 -1.854578373820305e-5 0.00013031012816430433 
0.0003177249719862052 8.095703118091063e-5 3.2273707923773605e-5 -6.83035713030793e-5 2.094326767829102e-6 8.473579976782717e-5 7.616815595032174e-5 7.988529717625833e-5 2.4301883180320967e-5 -0.00014244094460661808 -8.437499751959269e-5 0.00012147499831700923 -0.0002837065575171148 -7.024015473360434e-5 -0.00016436256088706445; -1.9501303494379986e-5 0.00023845979199097058 0.00015677660308134302 -2.087751085933425e-5 8.187791722877616e-5 -7.301135878958982e-5 -4.421439184518717e-5 -0.00021810784341949415 -0.00019905043695768552 0.00017182048945971326 -0.00015571957733196605 8.38964352180087e-5 -3.774129890075401e-6 4.826751505890669e-5 -8.685466850588522e-5 4.102113057445605e-5 -9.827433491307288e-5 -2.9764309100647044e-5 -0.00011055555473729256 -2.831892193069441e-5 -6.812826494644003e-6 6.973318709349611e-5 -0.00014167028525559146 3.295297397118887e-5 0.00011231619469173749 1.4760809836975175e-5 -7.04660032614516e-5 -3.644156636153713e-5 -0.00013884607865189674 -0.00010017020208476827 2.8877997511831343e-5 -0.00017509874119139078; -0.00020911376354581425 -7.536416391906763e-5 -6.166926918614749e-5 -0.0001332242741651709 -3.713777782343708e-5 3.8412833831325205e-5 -3.638194770848067e-5 0.00020342312332410879 0.00011754009160847362 3.8100833492870715e-5 -0.00014813064494467127 7.02401825523172e-5 -9.135743907484453e-5 -2.3750637907715313e-5 7.45775336812396e-5 8.434576060038397e-6 4.24763288295893e-5 5.540105253509776e-5 9.693268524069442e-6 6.30317659067208e-6 -8.08906539047087e-5 3.6141015729200527e-6 -6.828448253747945e-5 1.6866311656211883e-6 -9.788571928459654e-5 8.37825586617124e-5 4.819713053927758e-5 -0.00010719598371600212 0.00017491117746737975 -7.129478630907116e-5 -0.00015710617894603478 0.0001899974099190977; -6.82527480073268e-5 7.176420871314988e-5 1.2672597166149033e-5 -7.850110892988483e-5 -0.00039940284665294316 -8.811072080947753e-5 -3.0617322465042944e-5 0.0001158184949765681 -9.032319268710506e-6 4.981309206255171e-5 8.993351906454118e-5 
-1.255612402657291e-6 2.594869505470367e-5 -0.00011819705292668427 -5.481657409937026e-6 9.792483427942776e-5 -0.0001266117415632305 -0.00012811172388114446 -0.00014531742823396148 -0.00010627525813265048 0.0001256267768787913 -0.0003276104490047826 9.187956851589805e-5 -0.00010782758368963505 -1.6017931613303488e-5 -0.00010602922890185798 4.596218599387541e-5 0.0001584808105128586 -1.5635747387680954e-5 3.480896879251378e-5 9.781086367937357e-5 3.459730754753824e-5; -1.247303527889215e-5 9.694774842248403e-5 -0.00011906824449654248 0.00011465150662588318 3.155392612572669e-5 -8.267042660261628e-5 0.000170475643860446 1.9413261431712066e-5 7.101855651599975e-5 0.00013905670555434298 1.5805612453913585e-5 6.435850871383006e-5 -2.5251016020632203e-6 8.46247136698552e-5 6.236372489837554e-5 8.57237325155508e-5 -4.535233831373474e-5 -6.183778834636342e-5 0.00011811016945759617 -6.254099509786025e-5 -5.569632350923402e-5 -3.106490050672278e-5 0.00020598655162503283 5.3631180682889884e-6 4.5976340222499446e-5 -2.3435071236969603e-5 -1.701925914202807e-5 -5.599895605226027e-5 0.00011526288351633423 -0.00022092963565481548 0.00013118646227412012 -4.795727202494773e-6; 7.62522637092722e-5 7.847781179642725e-5 -8.281976389063375e-5 0.0001460146255272653 0.00011454137947347422 -8.018728604603154e-5 -3.6579728270368546e-6 -2.6047347462201043e-5 -5.461285164293497e-5 4.9749344355030474e-5 1.1886130407753417e-5 0.00011439295721410008 -0.00020970407550286124 -0.00012041452046779514 -0.00013609343059889755 -2.448318574116701e-5 -0.00022539470720167815 -5.973656280494814e-5 4.1347392282285115e-5 1.5695840004435145e-5 4.459300416109572e-5 -4.0615109726255564e-5 1.0062330686730213e-5 5.9610129567397996e-5 3.6369695037672784e-5 8.42077011771711e-5 -3.147721273072668e-5 1.5265941867790626e-5 -1.1774219796399575e-6 -7.012624465224891e-5 0.00026326978203777466 5.4683454975554734e-5; -0.00023150456250071774 -4.4311398917431975e-5 7.29628969995858e-5 -2.8592403716843064e-5 
-6.419061387781504e-5 0.00017745620575375406 1.8647436852250974e-5 -6.729105146166363e-6 -6.484630135013131e-5 9.5003060457441e-5 -8.654319240371025e-5 4.6085130154705364e-6 0.0001062557145335299 -8.559337723215332e-6 0.00020389133137293006 -0.00011082911655897606 4.125392447100534e-5 0.0001995360450084834 -1.258931754579903e-6 6.779834952545838e-5 -0.00020530899051003626 -8.746122907971831e-5 0.00014620180595308237 -0.00024511808292366994 -0.00014362381475870037 7.610768229575621e-5 5.6208101930009654e-5 -2.9943188885897345e-5 -2.9662445310669613e-6 2.596097114952816e-5 -8.221143648619035e-5 7.732945382237037e-5; 0.0001624862983198369 -4.445436356315338e-5 -9.27571396851374e-5 6.966884244042503e-5 0.00015391206248471638 -4.1241944602978004e-5 4.921891591671225e-5 0.00012304950210539702 -7.626213632016823e-5 -2.860930875892236e-5 -0.0001106736032334628 -8.354699243050388e-5 -4.72269870721536e-6 7.504929668078356e-5 1.4004639849139884e-5 -9.519243907840338e-5 8.232021205634129e-5 8.669103986624727e-5 -8.852242574965063e-6 -6.259146325016251e-5 0.00012959026785841268 -0.00011558221889849218 5.441935804247033e-7 -6.849914793116776e-5 -4.3735768161303786e-5 -0.00017234395785913658 6.968041848898536e-5 -0.00011355197934159924 -0.00010883194194599149 -2.8717240314168328e-5 -4.0194090291204196e-5 3.1114365226070884e-5; -4.708056185742136e-5 0.00012503137877093072 -0.00010857019624784945 3.412844925655083e-5 -0.00011791569814203595 -6.29299112956623e-5 -0.00021243101328255985 -4.840704896605099e-5 -5.207886825198923e-5 -0.00015483459101642115 -1.9092346705714386e-5 1.2115218665315157e-5 2.5165309848957073e-5 4.838421845807528e-5 -6.850701284242705e-5 6.83570386131804e-5 -0.0001362068152998185 -0.0001044182074151599 -7.538141035592394e-5 9.153562750316936e-5 2.397864299883068e-5 -3.3605852153780985e-5 4.418024654617382e-5 -0.00011485977411620353 -7.843618478045414e-5 -8.511719461644647e-5 4.590273234066957e-6 -8.975726923790264e-5 3.698002801278101e-5 -3.0675296515696694e-5 
-2.048662027341837e-5 8.826644515967443e-5; 0.00017643371372615295 2.3208749347428596e-5 -5.730214636643523e-5 -4.1981999153354295e-5 -0.000104863858929426 0.00022033816619444546 -0.00011402057155140486 -2.8093258357661833e-5 0.0001222455537332458 7.427425568824808e-5 0.0001491270884558696 -4.485125121065176e-5 -7.085487616347796e-5 3.1248673416151004e-5 -6.810055513270025e-5 0.00012097560809126518 0.00014908600839919905 1.664649405359783e-5 5.279028825828685e-5 7.511992568982649e-5 0.00013074005962483736 1.896690061908626e-5 -0.00013717394197227333 4.8522455266385416e-5 8.341954505670982e-5 8.738327538610483e-7 -9.712915218768304e-5 -5.7676374332336e-5 6.192045091084746e-5 0.00014351493862125328 5.513908015345918e-5 -3.448476511618804e-5; -6.495387900641392e-5 6.845879285111862e-5 2.4338748301738194e-5 -0.0001620030245582796 6.668564198054053e-5 0.00011695789613742034 0.00021469061790709216 2.6416798176171888e-5 -8.992700280641535e-5 9.459584431471147e-5 6.703437862897185e-5 9.976615434723247e-5 -0.0001485184341042377 6.669964092298948e-5 -1.0352871174032705e-5 -2.0531956794638517e-5 -3.9909887379227574e-5 -1.728291975877042e-5 0.0001331210212477446 0.0001285847236097754 0.00011164563217181542 -0.00013239747316833332 0.00015393548416186445 6.45324751197313e-5 5.0883408005859974e-5 -3.2222392171310666e-6 -0.00012428472222081828 0.00010199219719953607 -0.00013944007626988026 8.697630870602246e-5 -0.00010741956178818492 3.0344538357264123e-5; -7.868806383674454e-5 0.00015028773381737913 0.0001341936357169514 1.8898098318750037e-5 -6.647467048941208e-5 3.567744610728677e-5 -3.14717591715375e-5 4.922551565351837e-5 -5.684224484474137e-5 9.868146193576897e-5 -0.00020508652337928755 7.496256770930257e-5 0.00010268605441916836 -6.457578379530464e-5 2.3081060903061187e-5 4.441225868896867e-5 4.664707634612044e-5 -7.757918606847569e-5 -0.00013522664917735397 -1.8174713800987372e-5 0.00010505054419292064 -0.00023647360754300254 0.0001440451077035663 6.667477904949357e-5 
-2.3173049299373127e-5 2.8830343984013007e-6 -3.424668207633943e-5 -0.00022863236622987375 -3.988970379912768e-5 -3.415475162623936e-6 0.00010923394333469381 -0.00010998200846705475; 3.416820389032814e-5 3.079927180590273e-5 -1.6100426839649663e-5 5.741807897505901e-5 -2.638916205103244e-5 4.9659812682534597e-5 -9.17024067551608e-5 3.480843722866118e-5 -0.00010099585644854296 -0.00015307003788837071 1.2885738649525831e-5 -4.190936171867482e-5 0.00017781154154147559 -2.867774539697602e-5 -5.1455665491046325e-5 -3.3819408348501164e-6 -0.0001782535241749186 4.3096406240004135e-6 0.00011236418638657144 1.473886775568188e-5 2.710796568646732e-5 -0.00010239905397824271 8.090843901084851e-5 0.00017430693103745017 1.4507150832501497e-6 -1.0459734332799193e-5 -9.894000552082426e-5 -7.03655446362754e-5 8.246312197810379e-5 -0.0001695905798645727 -9.585459931473482e-5 -7.262241572095459e-5; -0.00015964906846760676 4.857807003305767e-5 -5.4873547145456385e-5 -0.00014450480297795428 -4.210987676173006e-5 -0.00012539225930935263 -4.662272398831759e-5 9.072580074304245e-5 4.2765529411791615e-5 -0.00013549140489316067 -4.658120530985959e-7 0.00017217487844050692 1.719144971162906e-5 0.00014467631340836604 8.719859839062159e-5 9.384461787941508e-6 -6.049629649839677e-6 0.00015084510310431465 6.2414831387526e-5 2.4025394719740984e-5 3.60579550545151e-5 -4.17519396648889e-5 8.103031956362244e-6 -1.6674154852568715e-5 3.343506328407217e-5 -6.749087983165834e-5 9.771886184922655e-5 -1.5487269308654385e-6 -8.478014855767466e-5 -4.3941955613024745e-5 -9.064599648239997e-5 0.00015966868600686431; -7.518077395601692e-5 0.00021259960417424498 -8.823449675204497e-5 4.3982240504338335e-5 7.477962285534671e-5 0.00010692400323947155 7.3812060237874595e-6 4.427395183498217e-5 -7.993172689552097e-5 -1.8705209322290666e-5 4.1242136377697264e-6 1.710444016695188e-5 -0.00012053396417272028 -9.184935984180107e-5 0.00010675458256642876 -0.00013552966188384814 -1.332903516398978e-5 6.633052953618553e-6 
5.039394532119122e-5 -0.0001410190954173616 -7.295805653341267e-6 9.91406858249005e-5 -1.1474411758027133e-5 -0.00018028598346958656 0.000139445991271857 -4.101712252107894e-5 6.595184777469747e-6 -5.190476363153949e-5 -5.976361276210804e-5 3.408484631123278e-5 2.9387877339250665e-5 3.955880752087826e-5; 8.097056630527013e-5 0.00015928577991896045 -1.1265278626305848e-5 -7.95398817351716e-5 -0.00014184292453113377 -5.029503230592621e-5 -9.803637007259373e-5 -1.6087697892858158e-6 8.6188672103397e-5 7.251102866600007e-5 1.5591024894322623e-5 -9.378022590659196e-5 0.00018773052503116714 0.00013281657371432102 8.920972974032934e-5 -0.0001667330005509417 9.291851816693702e-5 -7.766045095171176e-5 -6.0670555139709525e-5 0.00021219946862426192 -4.465422247424991e-5 -5.2662989073516257e-5 -3.13120952703176e-5 1.8998885170971003e-5 -1.9281588574571368e-5 -7.629120304022805e-5 5.535594678827429e-5 6.445647810348002e-5 0.00015292129606700954 -6.039846706675105e-5 7.297106563805691e-5 -7.159117867321131e-6; -3.915686619426013e-5 -2.5472903587673542e-5 -3.054136112755512e-5 3.839765332681285e-5 -4.199095354839623e-5 -1.1827983769638615e-5 4.912733703545275e-5 -1.1185975097636284e-5 -3.876806084723334e-5 -2.1429397637476224e-6 -0.00010326468844856054 1.8300199254976524e-6 9.062566306250832e-5 0.00014975590812549552 -3.8895608384205427e-5 4.72103175719867e-6 -4.175899965763407e-5 5.7741252305401095e-5 5.0514614769905716e-5 -2.4327329527171432e-5 -0.0001240132311867128 -1.7706269943889825e-5 -8.957234757677692e-5 -3.005375555209234e-5 5.27753020943564e-5 -9.99690798270029e-5 2.4764855254161215e-5 4.783884133952248e-5 -5.025183571693426e-5 8.438821743051262e-5 2.0081735353257067e-7 -1.6331905215212927e-5; 1.6624160987897704e-5 0.00012956387063987673 -2.070445924963053e-5 -5.4022666237675365e-5 6.746819263590647e-5 -0.00010586521392887307 0.00014648640927427463 0.00010553261227706576 -5.2910194146341304e-5 0.00013290728957891584 -0.00014006620921626248 -7.293810057652409e-5 
1.4726231991709102e-5 -0.00010275971152786008 -4.59701602192443e-5 4.837850275017649e-5 -9.491088139092353e-5 -8.69142692788863e-5 -4.357483129914695e-5 0.0001256731396171839 2.6804317083779056e-5 -6.285328135325462e-5 -0.00015067436624284467 -0.00013366835782270702 5.6544695606326005e-5 5.1112502031368775e-5 -7.029708803364672e-5 5.715151047134902e-5 0.00018196553565595728 -2.0401964949804875e-5 5.777648976455971e-5 -7.289228387143235e-5; 1.8042144675531566e-5 -8.147228199259099e-6 3.669265048370872e-5 -0.00015387533143369794 4.021861228536143e-5 -0.0001094990185622287 -2.4812517397052085e-6 -5.4666379709697226e-5 3.2453517506625276e-5 1.2886975001283752e-5 3.628304681184465e-5 -0.00010110748202281754 0.00014104258649310223 0.0001228216823262931 4.122706001067971e-5 3.277005076667248e-5 6.91476167118586e-6 -5.91349090784481e-5 0.00011294388592502348 -7.635060695904993e-5 1.4029755094061392e-5 5.4532400652106e-6 -2.8673635951449313e-5 0.00012895365567671345 5.12652367076681e-5 -0.00011615812776584549 0.00014292844194712057 0.00011595571037175421 -5.814286863754718e-5 -7.700431537089747e-5 0.0001729217727741002 -0.0002247168632850225; 5.598975857016796e-5 2.4207918925423666e-6 2.9711291594381566e-5 0.00016417375316838857 1.6056263652240648e-5 -4.089402265501517e-5 -0.00016605412512557213 0.00017448576416816293 -0.00011740427854940846 -2.6518021891864724e-5 1.9842875721470515e-5 3.923802304076761e-5 -3.6965278391767696e-5 0.0001548652567946879 4.740519815119879e-5 7.681087898551846e-5 8.141861925130339e-5 3.1869766266277924e-5 -6.45884660259293e-5 6.892744826723731e-6 -0.0001462372833118241 0.000143596923029928 -4.074540394479102e-5 -4.47886426770535e-5 -5.7920259821219586e-6 3.872276518838552e-5 -6.144875207221716e-5 -4.6667022084624486e-5 -4.2716293515449144e-5 1.0570691833231093e-5 -0.0001245892866934178 4.5512321398086375e-5; 0.0001336553989639963 -0.00019328966911787477 -6.316989922562912e-5 -6.689801987162626e-5 -0.0001795848645306943 5.155583246589623e-5 
1.991799333700527e-5 -4.812691506602089e-5 7.689316840642471e-6 2.63166032490708e-5 -2.560867026117872e-5 -8.283671806932186e-5 -6.542260846260322e-5 -7.963407435825033e-5 3.5682908925052764e-5 -8.579804420204212e-6 -8.547311037916308e-5 2.2309196789067605e-5 -0.00012304098229294624 -2.8429764574025658e-5 -0.00020388915320297996 0.00011940523224985653 4.322595602929608e-5 -5.099645452362297e-5 -9.221963293755738e-5 -0.0001006059598912903 -4.001460795194158e-5 3.466801470123962e-5 1.1321676789476702e-5 -1.0078067793136257e-5 1.1107398253080842e-7 1.9005999343727747e-5; 8.009597003426923e-5 8.859990556692148e-5 9.733113473946074e-5 -5.573516465058095e-5 0.00012154748618803296 1.7488439326084422e-5 -2.8435909925370127e-5 -1.0670230594145793e-5 0.00012157669915785004 -8.417631543972153e-5 -8.697989514369026e-6 -7.055285605868378e-5 0.00010530799262658828 0.00012774680580213016 8.736840608496425e-5 -1.9671724302788426e-5 -0.00015636942917746297 -9.217168336300776e-5 0.00012078070211890443 -0.00010406411966806918 -3.668870122902654e-5 -0.000107215686156758 6.959939709366105e-5 2.109150626955335e-5 -9.943516270988246e-5 -2.7679286731040092e-5 0.0001011849038974835 0.00020617689951057892 4.4102233360424956e-5 -0.00016865738653843172 1.3939611918135891e-5 -8.568917620261076e-5; -1.9764056650568438e-5 0.00016394102558539597 -8.680242908605061e-5 -0.00014127285257866125 -7.2859962522975506e-6 -2.3878885348797995e-5 4.331385198446886e-5 0.00013790143557404543 6.636773459892704e-5 7.294916567904288e-5 -2.7636642789115413e-5 -9.839439016951995e-5 -6.417682025069654e-5 -4.6029894382930336e-5 -4.4378608526421196e-5 -1.097520007472461e-5 0.00018886385796637847 -6.758517519605307e-6 -0.00012948690718925102 -0.00010354400737863512 0.00010292071956488657 -5.853799311744753e-5 -8.075962987216058e-5 5.8219404218714205e-5 -0.00017127516193735066 -4.807420564186172e-5 -1.1529049606294859e-5 0.00019305131709249142 6.370388282509835e-5 -1.6957594419462947e-6 -0.0002056711892685793 
-4.008851838019014e-5; -0.00016451621703096646 -6.20748010835775e-5 3.3340960878733463e-6 0.0001198499543686312 0.0001243506361940717 -5.385867696829783e-5 5.994650420393004e-5 4.542742537225837e-5 -0.00011291158951578022 9.705864082865717e-6 -3.553280256098158e-5 6.433529999116051e-5 -3.1394288801742034e-5 -2.5662902392718667e-5 -4.819909880394739e-5 0.00018470880824382151 -4.0358985264241574e-5 -0.00023431765074135482 6.716883437250108e-5 4.9830038203003975e-5 -0.00018061684894135 -6.8862454630338e-5 -3.031193553694708e-5 -0.0001900872935796309 4.060521897983955e-5 8.149243732690032e-5 0.000173720409672323 -0.00011103526556623855 2.34109385631348e-6 5.621932943243058e-5 1.9283261660106276e-5 4.1039593649408634e-5; -3.5486452634588756e-6 0.00014384404700646365 9.069888865254701e-6 0.0002145464334443892 -1.2162974232773962e-8 -0.00014184699623914155 -2.3224543592741854e-5 7.868417948881385e-5 -8.042804777690405e-5 -1.8247985035519067e-5 4.995477708687232e-5 -3.967478708547581e-5 1.6246371405911783e-5 8.329844849836838e-6 5.497176450260567e-5 -3.440537487224833e-5 1.4144207622551735e-7 -6.082370030529554e-5 -0.00017639191876699854 2.1533356875775515e-5 -5.736995238297421e-5 0.00015262874719154982 -3.418726349084693e-5 -8.132272135302076e-5 -4.137810331488767e-5 -0.00011322518342242197 -0.00016458084221997432 6.053201805084186e-6 3.448942722568409e-5 6.188904287225135e-5 -8.01135663368979e-5 -1.7205804151701454e-5; 0.00011260161289654966 -5.576447921509265e-5 7.857174813220859e-5 6.935809386603495e-5 -7.72161450312261e-5 7.586571763215787e-5 0.00010433031155664544 -0.00020737471760537003 -7.050542018432034e-5 -6.423943824653844e-5 5.546486016654912e-5 -1.5253030427578532e-5 0.0001722608855617355 4.016871997566992e-5 -3.149311788904863e-5 0.00018649031579671404 9.879232555806308e-5 8.1572396708116e-5 0.00026410625895198595 -1.3369441434289992e-5 -5.871873628517974e-5 0.00014210831312631183 0.00011656317536662803 0.00022888755364522262 -0.00012241066844023173 
7.795614391041467e-5 5.852084603803123e-5 -8.451952572965547e-5 3.202992473788283e-5 0.0001550091243972185 5.7139425449421965e-5 -4.679756722351298e-5; -0.00014854491070968943 7.591558772015952e-5 4.979212410738497e-5 -3.0242207296626045e-5 -4.2308810018073004e-5 2.5183382083708062e-5 0.00011699028930500053 3.1921957477392295e-5 2.3413296210269746e-5 -9.100888357566556e-6 5.522541635195881e-5 -0.0001558455338201953 9.667871147736109e-5 8.616613870595975e-5 5.559905496533478e-5 8.611139440086363e-5 0.00010408348992110473 -7.648678781355239e-5 0.0001053645022946172 -3.8337479574908996e-5 -2.885412934334244e-6 1.8401762124742274e-5 -8.278960792472114e-5 -9.701638971113266e-5 7.764793502043568e-5 7.879854768156826e-5 1.2834507677102373e-5 -2.7907032356084753e-5 0.00013341807737356264 5.2922601232930215e-5 -7.752159633333068e-5 -2.550351980266046e-5; -0.0002136659417951787 3.943297121128239e-6 -3.2221363165698837e-5 0.00011181344283294849 5.406329049625039e-5 1.4841910830418352e-5 -1.4544075051695514e-6 7.127642253596754e-6 -6.887757941833457e-5 -4.905020201799194e-5 -2.225686827251666e-5 -1.3017855475012485e-5 -0.00019908710552366134 -9.576891485588892e-5 1.835468404943181e-5 2.6265287066487545e-5 -7.491676248131253e-5 -1.1417861481770195e-5 -5.3151903069683696e-5 -3.393635730712066e-5 -0.00012196291523973167 -0.0001862119638312372 -0.00022681742435772193 0.0001997570694006407 -5.1654158108743895e-5 7.554362358541539e-5 -0.00010477638303520657 -0.00022782651236733967 -4.573423890498099e-5 -1.2993203621165352e-5 7.513753056309551e-5 0.00018934091034358724; -0.00010067832829693484 -3.940976173478804e-6 -0.00010592665114702907 -2.7050783162421294e-5 -2.46885325646219e-5 -8.358098478086343e-5 0.00012999843545227928 5.4111040958303895e-6 6.908661395979939e-5 0.00011320505961745938 0.00014282309655070158 -7.956300453545759e-5 -0.0001690988652876772 1.0778955013367284e-5 -2.00568562172792e-5 -3.8434850127968156e-5 -9.504095999321034e-5 -9.479327334273128e-6 
-4.707455866767144e-5 0.0001247913963856916 -8.323175336729776e-5 7.494108228681262e-5 0.0002196537019449088 -2.6913933313633895e-5 -0.00012417941673222965 -6.323948605582557e-7 2.143963263049145e-5 0.00014800676524143625 2.8760357392885234e-5 9.690911190088762e-5 1.6894751181526224e-5 -0.0001251719883180417; -7.391299387664566e-5 4.013035130674775e-6 -0.00020673247146175787 -0.00016476080437415004 -0.00011633368213656112 -0.00011271557396971321 0.00013387873676852128 1.728537225867669e-5 -6.846012203016079e-5 9.857708554486668e-5 -3.122193103182356e-5 6.524638937564465e-5 -0.00010678708732588064 2.049568301013601e-5 -2.6225361772132305e-5 -8.556680787580629e-5 -9.384057576529817e-5 -2.0398640309528974e-5 8.742750996030277e-5 7.944082234512149e-6 -0.00010026846958785364 -0.00021170442433780177 6.683448354022443e-5 3.6620451090909504e-5 -2.375452296396316e-5 -8.982533855587508e-5 -0.00019141751200652213 4.7802397410898244e-5 4.585310652045165e-5 5.417252012999955e-5 -0.00011780766018191552 -9.185765176074376e-5; 3.8876030420056846e-5 -3.09381980617943e-5 -9.406764412357832e-5 6.94074073986726e-5 -6.131290784996298e-6 3.822248935523095e-5 6.902496123860539e-5 -0.00010913899724198885 -2.0571126252776984e-5 -1.9289693692717046e-5 4.516394209406435e-5 -5.253947939433585e-5 9.736395030688257e-5 5.242243740996942e-5 1.3216560990830023e-5 4.8132885684593746e-5 4.013408714729082e-5 -1.3559325917758034e-5 -5.0248924263886804e-5 5.6036052672989956e-5 -3.6083798942290485e-5 8.487633339303635e-5 -8.468923159631035e-7 -4.222558755953178e-5 2.224531589365165e-5 -6.247648390069384e-5 -0.00011952638716263762 7.498336662012706e-5 -4.2562227928442115e-5 -2.2542950783988885e-5 -0.00012716258471425388 -7.77966529416316e-5; -2.675249695140807e-5 6.291309776303569e-5 4.982039766209795e-5 -0.00015119554652311639 -4.008606399123204e-5 0.00015823453079962552 -5.293600540434268e-5 2.9592920809469937e-5 -8.794531228897826e-5 -0.00016994023232667503 3.403870371574508e-5 5.22262039611619e-5 
3.119732585207361e-5 4.748680888030167e-5 -3.0563367063321164e-5 -0.00011369123009910472 4.22894577512764e-5 -0.0002596070821438234 -0.0001287386525929214 0.00014614859767777305 -1.7066978644046608e-5 -0.00012042563091207093 -7.09726714104548e-5 0.000149536356302514 -1.0655612159220259e-5 -4.81794863572865e-5 -7.692391724969822e-5 -2.9110780333333724e-5 -0.00014570427944828846 -2.0346600167729667e-5 0.0001348050741670601 8.860726896261623e-5], bias = [2.1219726817136153e-9, -2.7047907512239253e-9, -9.227247859454673e-10, -3.2978839481491646e-9, 3.158381535854506e-9, 1.9880880092105761e-10, -2.968160854981151e-10, -1.158371623119528e-9, -4.1483913231586815e-9, 3.5577373969175492e-9, 2.5454604051546816e-9, -7.795785561594324e-10, -1.6284706148100028e-9, 5.028198427427688e-10, -3.5574870946581687e-10, 1.6945343904394958e-9, -7.473207588823518e-10, -4.599240038028779e-10, 7.674868994592897e-10, 6.753764648407899e-10, -4.058823905342033e-9, 1.0571881700945696e-9, -1.3170386723226517e-9, -9.80056279977105e-10, -1.271228061951345e-9, 5.412117509622856e-9, 1.8985044920185965e-9, -4.230481448743279e-9, 1.5222437641000436e-10, -5.067896517753051e-9, -7.025950069552352e-10, -2.3114536216465203e-9]), layer_4 = (weight = [-0.0005800689645775102 -0.0007139370045078659 -0.0006993937213374705 -0.0006793257948758238 -0.0006379823258228933 -0.0006722661168550341 -0.0006846551009552927 -0.0007181141435194288 -0.000682804576604014 -0.0006842329680702028 -0.000688798799858485 -0.00086468620587217 -0.0008109700352618887 -0.0005344519919692209 -0.0009347007927464384 -0.0006449919258506389 -0.0006315832474572034 -0.0005538052833562058 -0.0007969079914434793 -0.0008172104843616806 -0.0007545810777475232 -0.0007846753707610751 -0.0007818773432832299 -0.0006773498411333549 -0.0006173134276369011 -0.0007275826145717864 -0.0005645511455792566 -0.0006330695026753961 -0.0006619615276924456 -0.0005550222815133367 -0.0007631511738871792 -0.0006508642187823361; 9.872318649164789e-5 
0.00010324428215345283 0.00021678619486045555 0.00019706848355085642 0.00015169874211419656 0.0002541420115112627 0.00028693937094723544 0.00026084661062728443 0.00048054597846747786 0.00031825407073012875 0.0004217883308530572 0.0003281723363770098 0.00019737819182944958 0.00022445833699098706 0.00016078547959579414 0.00017468925262878186 0.00011844260442566738 0.00020490594016843488 0.0001833662355982466 0.00028896964551477133 0.00010787676487700777 0.00011713001680485277 0.000255179980328889 0.00027383208454490477 0.00023982284574540315 0.000134711877340374 0.0004017738993329732 0.0004939646162000657 0.00018792398023951103 0.00027143326237118816 0.0001293488286563957 0.00015263522340889618], bias = [-0.0006883451945975125, 0.00025850668463625544]))
Visualizing the Results
Let us now plot the loss over time
begin
    # Training curve: one entry per callback invocation (see `losses`).
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(
        ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle
    )
    fig
end
Finally, let us visualize the results
# Re-integrate with the optimized parameters `res.u`, using the same fixed-step
# RK4 settings as during training.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn =
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false) |>
    Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    # Overlay the data, the untrained prediction, and the trained prediction.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, strokewidth=2, alpha=0.5
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, strokewidth=2, alpha=0.5
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        markersize=12,
        strokewidth=2,
        alpha=0.5,
    )

    # Each legend entry combines a line with its matching scatter markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

# Print GPU toolchain details only for backends that are loaded and report functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.