Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.
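We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use the Newtonian formulas $r_1 = \frac{m_2}{M} r$, $r_2 = -\frac{m_1}{M} r$ with $M = m_1 + m_2$ to get $r_1$ and $r_2$.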
# Split a one-body relative path into the two individual body paths.
#
# `path` is scaled by `m₂ / (m₁ + m₂)` for the first body and by `-m₁ / (m₁ + m₂)` for
# the second, so that `r₁ - r₂` reproduces `path`.
#
# Returns the tuple `(r₁, r₂)`.
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$:
# Convert an orbital solution in (χ, ϕ) variables into Cartesian coordinates.
#
# Arguments:
#   - `soln`: matrix with 2 or 4 rows and one column per saved time. Row 1 is χ and row 2
#     is ϕ; with 4 rows, rows 3 and 4 hold time-dependent `p` and `e`.
#   - `model_params`: required when `soln` has 2 rows; a length-3 collection `(p, M, e)`,
#     of which only `p` and `e` are used here.
#
# Returns a 2×N matrix whose rows are x = r cos(ϕ) and y = r sin(ϕ), with
# r = p / (1 + e cos(χ)).
#
# Throws `AssertionError` on an unsupported row count or on missing / mis-sized
# `model_params`.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    # The first two rows have the same meaning in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # The default `model_params=nothing` has no `length`; test for it inside the
        # assertion so the caller sees this message instead of a MethodError.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
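This function computes the first time derivative of uniformly sampled data. It uses second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0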
# First derivative of the uniformly sampled vector `v` with sample spacing `dt`.
#
# Interior points use the central difference (v[i+1] - v[i-1]) / 2; the first and last
# points use three-point one-sided stencils. `v` must have at least 3 entries.
#
# Returns a vector of the same length as `v`.
function d_dt(v::AbstractVector, dt)
    left_edge = -1.5 * v[1] + 2 * v[2] - 0.5 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    right_edge = 1.5 * v[end] - 2 * v[end - 1] + 0.5 * v[end - 2]
    return vcat(left_edge, interior, right_edge) / dt
end
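This function computes the second time derivative of uniformly sampled data. It uses second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0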
# Second derivative of the uniformly sampled vector `v` with sample spacing `dt`.
#
# Interior points use the three-point stencil v[i-1] - 2 v[i] + v[i+1]; the first and
# last points use four-point one-sided stencils. `v` must have at least 4 entries.
#
# Returns a vector of the same length as `v`.
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end
Now we define a function to compute the trace-free moment tensor from the orbit
# One component of the trace-free mass moment tensor along an orbit.
#
# Arguments:
#   - `orbit`: matrix whose first row is x and second row is y, one column per time.
#   - `component`: indexable pair; `(1, 1)` selects xx, `(2, 2)` selects yy, and any
#     other value selects the xy product.
#   - `mass`: scalar multiplier applied to the result (default 1).
#
# The diagonal components have one third of (x² + y²) subtracted.
# Returns a vector with one entry per column of `orbit`.
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    xx = xs .^ 2
    yy = ys .^ 2
    row, col = component[1], component[2]
    moment = if row == 1 && col == 1
        xx .- (xx .+ yy) ./ 3
    elseif row == 2 && col == 2
        yy .- (xx .+ yy) ./ 3
    else
        xs .* ys
    end
    return mass .* moment
end
# Twice the second time derivative of one moment-tensor component.
#
# `dt` is the sample spacing of `orbit`; `component` and `mass` are forwarded to
# `orbit2tensor`. Returns a vector with one entry per orbit sample.
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end
# Quadrupole components for the (1,1), (1,2) and (2,2) tensor entries of one orbit.
#
# Returns the tuple `(h11, h12, h22)`; note the order differs from the diagonal-first
# order in which the entries are conceptually grouped.
function h_22_quadrupole(dt, orbit, mass=1)
    entry(i, j) = h_22_quadrupole_components(dt, orbit, (i, j), mass)
    return entry(1, 1), entry(1, 2), entry(2, 2)
end
# Strain pair for a single orbiting body of unit mass.
#
# `T`, the type of `dt`, is used to build the constants 2 and √(π / 5).
# Returns `(scale * (h11 - h22), -scale * 2 * h12)` with `scale = √(π / 5)`.
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    scale = √(T(π) / 5)
    plus_term = h11 - h22
    cross_term = T(2) * h12
    return scale * plus_term, -scale * cross_term
end
# Quadrupole components of a two-body system: the component-wise sum of the
# contributions of each body, each weighted by its own mass.
#
# Returns the tuple `(h11, h12, h22)`.
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end
# Strain pair from the orbits of two bodies with masses `mass1` and `mass2`.
#
# The masses must sum to 1 (checked to within 1e-12). `T`, the type of `dt`, is used to
# build the constants 2 and √(π / 5).
# Returns `(scale * (h11 - h22), -scale * 2 * h12)` with `scale = √(π / 5)`.
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    scale = √(T(π) / 5)
    plus_term = h11 - h22
    cross_term = T(2) * h12
    return scale * plus_term, -scale * cross_term
end
# Compute the waveform pair for an orbital solution.
#
# Arguments:
#   - `dt`: sample spacing of `soln`, used by the finite-difference derivatives.
#   - `soln`, `model_params`: forwarded to `soln2orbit`.
#   - `mass_ratio`: value in [0, 1]. Zero selects the one-body (test particle) formula;
#     a positive value splits the orbit into two bodies with masses
#     `mass_ratio / (1 + mass_ratio)` and `1 / (1 + mass_ratio)`.
#
# Returns the 2-tuple produced by the selected strain function.
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    if !(mass_ratio > 0)
        return h_22_strain_one_body(dt, orbit)
    end
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
Simulating the True Model
RelativisticOrbitModel
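`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background. It uses the state $u[1] = \chi$, $u[2] = \phi$ and
$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}}$$
where $p$, $M$, and $e$ are constants.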
# Right-hand side of the relativistic orbit ODE.
#
# Arguments:
#   - `u`: state whose first entry is χ and second entry is ϕ (only χ enters the rates).
#   - `model_params`: iterable unpacked as `(p, M, e)`.
#   - `t`: time; unused, present to match the ODE right-hand-side signature.
#
# Returns the 2-element vector `[χ̇, ϕ̇]`.
function RelativisticOrbitModel(u, model_params, t)
    p, M, e = model_params
    χ, _ = u
    cosχ = cos(χ)
    numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end
# Simulation setup shared by the true model, the neural ODE and the loss function.
# Only `ode_model_params` is declared `const`; the other names are plain globals that
# `loss` reads later in the file.
# NOTE(review): `tspan` uses Float32 literals while `u0` and `dt` are Float64, which also
# makes `tsteps` and `dt_data` Float32 — confirm this mix is intended.
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4) # timespan for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
# Spacing between saved samples; passed to `compute_waveform` for the finite differences.
dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step passed to `solve` (the solves below use `adaptive=false`).
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic model with fixed-step RK4 (step `dt`, no adaptivity), saving
# the state at every entry of `tsteps`, and collect the result into a plain array.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a 2-tuple; only its first element is kept as training data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the reference waveform against time as a line with a marker at each sample; the
# line and the markers share a single legend entry.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
axislegend(ax, [[l, s]], ["Waveform Data"])
# Evaluate `fig` last so that the figure is the value of the block.
fig
end
Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (the orbital phase χ, i.e. `first(u)` in `ODE_model` below) and has two outputs. We'll make a function `ODE_model`
that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
It is typically never recommended to use globals, but in case you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use the initializers from `WeightInitializers.jl`:
# Network used as a multiplicative correction inside `ODE_model`: an element-wise cosine
# of the scalar input, two 32-unit Dense layers with cosine activation, and a linear
# Dense layer producing 2 outputs.
# All weights are drawn from a truncated normal with std 1e-4 and all biases start at
# zero, which presumably keeps the initial network output small.
const nn = Chain(
Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# Initialise parameters `ps` and states `st`.
# NOTE(review): the default RNG is used with no seed visible here, so the initial weights
# may differ between runs — confirm whether seeding happens elsewhere.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00016973146; 0.00010228566; 8.145117f-6; 6.4114465f-5; 5.702867f-5; 4.76904f-5; -8.669598f-5; 4.466077f-5; 0.00012299528; -0.00013257282; 2.7508391f-5; 9.581404f-5; 6.930337f-5; -1.6895068f-5; 9.430357f-5; 1.4941846f-5; 1.629597f-5; -1.7224564f-5; 3.4461067f-5; 7.419241f-5; -0.00016886815; 7.546448f-5; 7.958376f-5; 3.6587928f-5; 0.000119423166; -9.249538f-6; -0.00015180308; -3.2577173f-5; 0.00011929801; 5.5948887f-5; 4.750216f-5; -0.00012370358;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.00016769653 -2.7053862f-5 -0.00017082065 -0.00013386238 -0.00012508457 -2.2359136f-5 4.803735f-5 -1.834314f-5 -2.3634315f-5 0.0001629959 -0.00010172413 -7.4161806f-5 -3.979948f-5 -8.33191f-6 4.9077447f-5 2.6626702f-5 0.00013942784 -5.5488006f-5 6.507954f-5 -2.5138675f-6 -5.77159f-5 -2.330059f-5 0.00010912463 0.000152089 -0.000167332 -3.624775f-5 4.9033817f-5 -0.00013563607 0.00010465832 -3.232244f-5 2.710944f-5 5.413418f-5; 1.6381066f-5 -0.00011741099 5.4537275f-5 -9.254288f-5 -1.2287291f-5 -6.455412f-6 0.00012708895 0.00018568501 -1.6488117f-5 2.3316788f-5 -9.7968055f-5 8.780118f-5 -6.6398956f-5 -1.1050335f-5 3.2228178f-5 0.000109622524 -5.5525747f-5 -1.6263253f-5 1.3382361f-5 0.00023915576 -9.538718f-6 0.000219129 -8.315922f-5 -0.00015717986 -3.8555805f-5 8.827862f-5 0.0001116123 7.4829644f-5 0.000120325094 -1.6847967f-5 0.00013546148 0.0001787274; 0.0001612122 0.00012128526 0.00010605269 -2.9618752f-5 -0.000115038456 -9.620825f-5 -0.00019093987 3.0776353f-6 -8.025605f-6 1.4730061f-5 1.051255f-5 4.5201523f-5 4.5591034f-5 7.1813614f-5 -0.00021326578 3.8724895f-5 -0.00011551993 4.7922582f-5 -5.4708064f-5 -0.00016613574 6.794249f-5 0.00016775075 3.8174818f-5 0.0002002555 -0.00015533742 -0.00012101596 -5.440493f-5 6.2229876f-5 5.5554454f-5 -0.00019473345 
-0.00018833019 0.00010113753; -7.903826f-5 -0.0001946152 -2.9831795f-5 3.0625408f-5 -6.4318624f-6 0.00014819577 8.41512f-5 -0.000110261906 6.129719f-6 0.00016690549 -5.5277683f-6 -7.4071636f-6 -4.5201054f-5 0.000108418724 -5.2511234f-5 -7.990368f-5 8.8340465f-5 -1.6190564f-6 3.1085714f-5 -9.329487f-5 -3.0292043f-5 5.8949463f-5 -2.5343446f-5 0.00016575087 1.6501492f-5 -4.6059205f-5 -2.088261f-5 -0.0001052631 -6.352659f-5 -9.46828f-5 1.1019924f-5 6.3403226f-5; -9.061148f-5 0.000109711225 -0.000110999055 -3.3587103f-6 1.1241367f-5 -4.2724205f-5 7.142609f-5 0.000200749 -5.2439147f-5 -5.8536425f-5 0.00012446585 0.00025077802 -0.00013295746 -2.689875f-5 -0.0001395383 2.8880111f-5 8.221755f-5 -7.809975f-5 -0.00011476257 -9.664306f-5 7.887839f-5 -2.1453516f-5 -0.00018082904 2.5603624f-5 -0.00013389306 -2.0121446f-5 7.244338f-6 -1.3616092f-5 -1.3422798f-6 -0.00018805382 -3.2437692f-5 -0.000119407894; 5.2717947f-5 3.2133375f-5 -0.0001245017 -2.1593993f-5 -0.00011957095 -9.386015f-7 5.101488f-5 6.3298845f-5 0.00010547548 2.0019932f-5 2.2806224f-5 -8.174248f-5 -0.00016300187 0.0001257158 4.1105136f-6 3.1516884f-5 -3.22701f-5 -2.0359328f-5 2.6382833f-5 0.00010257602 6.4029984f-5 5.4013337f-5 -0.00017931231 5.2067073f-5 5.3026556f-6 -0.000133744 0.0001259458 0.00011074252 3.1218744f-5 2.3609507f-5 3.278697f-5 2.0932332f-5; -5.9503636f-5 0.00028371872 7.656597f-5 1.6952556f-5 -0.00013300958 -8.08706f-5 2.8496156f-6 0.00024047025 -5.975658f-5 0.00019887589 6.0197824f-5 -5.8530328f-5 6.902113f-5 8.792738f-5 9.2044094f-5 -2.7587304f-5 -9.2378184f-5 3.654951f-5 0.00016660463 0.00012311428 2.623431f-5 -2.8834304f-5 0.00010822345 -0.000116269446 -0.00019732933 -2.5959163f-5 8.404774f-5 -0.00017199952 9.4427705f-6 -8.4784886f-5 -8.800568f-5 -0.00010033761; 0.0001241156 -9.869965f-7 -0.00011090412 0.0001313972 7.117672f-5 3.117375f-7 -9.591258f-5 2.9992805f-6 -1.8227083f-5 2.535013f-5 0.0002689924 5.348652f-5 2.8266099f-5 9.696974f-5 -0.00011567071 -0.00010429217 9.8355966f-5 
-4.9239384f-6 -3.0191042f-5 -1.2781587f-5 -9.5767515f-5 -0.00012653176 -0.00019449383 -0.00012360964 0.0001319568 -7.117445f-5 -0.00010526782 0.00010099342 0.00014840218 -8.72102f-5 -5.4250613f-6 1.10883975f-5; 0.000101234575 -4.1177907f-5 -4.7842383f-5 1.4110997f-5 -7.27403f-5 -1.9906642f-5 0.00016979239 8.74302f-5 0.00010589976 -0.00010506188 -0.00012354516 5.1799143f-6 -0.00020537751 6.401821f-7 -4.3091575f-5 -6.6190056f-5 -2.369298f-5 -0.00015082367 -0.00012527655 -3.143154f-5 6.650369f-5 2.3701497f-5 0.0001924084 -8.8206085f-5 -0.00012317981 -9.7998745f-5 5.9200378f-5 -1.8463066f-5 5.4697848f-5 -6.026357f-5 -6.0754093f-5 -9.641196f-5; -9.7301534f-5 9.4413135f-6 -4.0601775f-5 -4.4820357f-5 -0.00013376813 8.512471f-5 4.201088f-5 -9.0759895f-5 -1.3368883f-5 -3.6679947f-5 -3.3092256f-5 8.4975625f-5 -1.6512802f-5 6.265098f-5 0.000105276675 5.747504f-5 -8.407184f-5 7.2301073f-6 0.000119908254 -0.00022196956 -9.330602f-5 -0.00013434725 0.00011878384 7.870482f-5 -6.6976514f-5 0.00011379847 1.267171f-5 0.00014010209 -5.0890812f-5 -4.8552247f-5 4.298911f-5 -7.004646f-5; -0.000107202475 -0.00013929911 7.045441f-5 -3.9763872f-5 5.8731814f-5 0.00011556638 -6.8454483f-6 0.00010908247 7.881514f-5 -0.0001243405 3.5713992f-6 -0.00013350615 3.13364f-5 -0.00012380138 -3.2567827f-5 5.012618f-5 -2.9728919f-5 -0.00013740669 -4.267815f-7 -0.00010447329 -2.5108144f-5 -8.786579f-5 -1.7541392f-6 -9.6108306f-5 9.540902f-5 -0.00013115874 2.4406208f-6 -2.3948272f-5 -1.8197738f-5 -2.1174234f-5 0.00011619745 -0.00015724862; -7.3536057f-6 -4.06912f-5 -0.00018162104 -0.00014059395 -0.00016906956 4.918992f-5 -8.071423f-5 -5.636761f-5 4.459829f-5 -0.00013532047 -0.00012732367 -1.7014352f-5 9.331706f-5 -1.9855992f-5 3.3769196f-5 0.00018180661 0.00015814372 -5.118432f-5 -2.5271775f-5 -9.4229246f-5 -7.964317f-5 6.6557186f-5 6.498607f-6 -4.2749674f-5 -1.3890157f-5 -0.00012996946 -8.178385f-5 8.860648f-5 -4.018831f-5 -0.00013451141 7.5904616f-5 -0.00010433804; -0.00012956912 0.00016967692 
2.3956452f-5 8.8744026f-5 -1.2909095f-6 0.00017668329 -7.923512f-5 -1.937524f-5 -0.00016417119 -6.884761f-6 0.00013588414 4.892569f-5 -0.00016490486 0.00010132075 -0.00018364833 -0.00014845731 -0.00033146937 1.9701545f-5 -7.414098f-5 0.00013074353 -1.1823386f-5 0.00016250297 5.240079f-5 4.2634932f-5 2.8285893f-5 -9.032822f-5 -6.106634f-5 -6.104349f-5 -0.000114093906 -2.7973003f-5 7.091321f-5 6.465503f-5; 4.2338717f-5 4.6096866f-5 -1.5276022f-5 -1.7098935f-6 -2.0595007f-5 0.000111610156 4.0228766f-5 -2.9077294f-6 2.7725617f-5 6.371609f-5 -0.0002645837 -5.2875643f-5 1.5182551f-5 2.8752997f-5 -3.7391044f-5 -0.0001063914 -0.00013660964 -5.0820385f-5 -0.00013373164 3.558385f-5 -6.2006725f-6 -9.249124f-5 -7.999738f-5 0.000113075526 -9.400236f-5 -3.3799595f-5 4.2381584f-5 1.734393f-5 6.3045925f-5 -2.7545962f-5 5.4402426f-5 0.00013743744; -9.652085f-6 -0.00013605242 3.1954708f-6 3.3060976f-5 8.9758754f-7 -2.5902427f-5 3.3561315f-5 -2.1114967f-5 -1.9376932f-5 6.910386f-5 -0.00012579752 -4.179472f-5 7.829575f-5 -2.2852f-5 0.00010480925 -0.00013302796 6.329709f-5 -0.00012853277 0.00012616478 9.8922086f-5 -2.643222f-6 7.035522f-5 -6.326776f-5 0.00015588786 0.0001697359 -4.324973f-5 4.976359f-5 -0.00012985378 9.4879346f-5 2.3129927f-5 5.5580393f-5 -0.00011672666; 9.686824f-5 -0.00010667465 4.360838f-5 5.157815f-5 0.00014342005 0.00012896415 -6.421653f-5 -2.6473606f-5 -4.862456f-5 -8.432468f-5 0.00012724158 -6.623511f-6 -4.7103407f-5 -2.4439096f-5 -0.0001282653 -0.00022639614 -1.4352391f-5 0.00016366084 3.0270847f-5 5.0351293f-5 -2.4752195f-5 -0.00014802237 3.4854085f-5 2.8173961f-5 7.569831f-5 -5.2400428f-5 -5.27727f-5 -0.0001384509 2.202846f-5 4.153753f-5 -4.2336087f-5 -4.702022f-5; 4.7801197f-5 4.9071998f-5 0.00010009839 -0.00017751318 -6.0194176f-5 -1.0641174f-5 0.00017997052 3.5234993f-6 -6.708185f-5 3.0930943f-5 7.942126f-5 1.0850849f-5 0.00011718841 6.309329f-5 -0.00014679029 4.681103f-5 -4.2138876f-5 6.615478f-7 2.4406601f-5 -1.6657328f-5 2.050794f-5 -5.8478894f-5 
-1.6188664f-5 5.0752034f-5 0.000112773814 0.00015356607 9.7418095f-5 -0.000107475214 -1.951616f-5 -0.0001749518 6.147433f-5 -5.490481f-5; -0.00012842107 4.245299f-5 0.00020506604 -2.1805285f-5 0.00011909729 7.375458f-5 -0.0001090639 -0.0002096534 9.863692f-5 -0.000100000536 -3.674996f-5 0.00013129444 -8.7627785f-5 -2.1309506f-5 -3.2921784f-5 -9.394566f-6 8.754906f-5 4.9508963f-5 -0.00014450005 -1.8693054f-5 1.8630099f-5 0.00014168731 -3.9662667f-5 4.3920736f-5 -4.8898874f-5 0.00016517285 5.4698383f-5 -0.00012118812 9.228276f-6 5.6882433f-5 0.0001383227 2.0800386f-5; -0.00011982376 -0.00024035302 -1.611136f-5 -1.5602546f-7 -0.0001864871 0.0001572566 1.6930911f-5 -4.826473f-5 -0.00013262029 -6.205309f-5 -5.257816f-6 -4.754955f-5 -1.4382929f-5 3.045013f-5 0.0001869962 4.859692f-5 2.6820619f-5 1.6504127f-5 3.707935f-5 -0.00024769266 -7.536935f-5 -9.819911f-7 7.8764744f-5 -0.00013835798 -2.8951386f-5 2.4836614f-5 2.32962f-5 0.00014605513 -7.029452f-5 -9.097774f-5 -3.4682926f-5 -0.00019510367; -2.0243431f-5 -3.9816165f-5 -3.6422112f-5 -6.97657f-5 2.756036f-5 -4.205604f-5 5.0818155f-5 1.9001542f-5 -2.5512441f-5 -1.9108056f-5 4.851034f-5 -9.211271f-5 2.4196028f-5 -0.00013067745 -2.2233753f-5 -6.386729f-5 -0.00014642438 -0.00012086219 -2.5003492f-5 7.166674f-5 -6.773163f-6 0.00012746977 -0.00019612556 -7.966778f-5 -6.542401f-5 2.0918065f-5 8.3689905f-5 -3.52944f-5 9.301851f-5 -7.119659f-5 4.5898854f-5 -4.8175283f-5; 3.3281503f-5 -3.6557598f-5 7.830076f-6 -5.930152f-5 -1.18308935f-5 -9.351764f-5 0.00016619446 0.00016418408 4.786829f-5 5.7041307f-5 7.507997f-5 -1.6111753f-5 4.2820975f-5 7.940262f-5 -0.00013913632 -7.1899696f-5 2.83323f-5 8.798112f-5 -2.9677565f-6 -7.568716f-7 0.00014804084 -2.8079941f-5 -5.6881995f-6 0.000109942994 -0.00012063137 4.676091f-5 7.275616f-5 0.00011944945 0.00011646172 6.7134955f-5 -0.00012930905 0.00018892673; -4.5197055f-5 7.6218974f-5 -4.841563f-5 -4.633064f-5 2.2889928f-5 -0.00012572587 -6.5777625f-5 1.6553591f-5 0.00020211482 -0.00015833005 
-2.350425f-5 -4.3017524f-5 -5.286671f-5 4.445738f-5 -3.957052f-6 -0.00015814314 0.000113502545 -3.6336445f-5 3.8445884f-5 0.0001321068 3.100041f-5 0.00016042832 -0.0001179981 6.094474f-5 0.00015125642 -9.37883f-5 7.3176176f-5 5.2412426f-5 -5.8152145f-5 9.120715f-5 -7.159329f-6 4.8692236f-5; -0.0001328779 -4.3363252f-5 -0.00010323852 -7.338743f-5 3.9123297f-6 -7.657276f-5 -2.603565f-5 -5.023191f-6 -4.8938135f-5 -8.537508f-5 0.00013527188 6.4307475f-5 -8.090411f-5 -0.00012418942 8.602142f-5 -2.8569259f-5 -0.000115219016 1.960304f-5 -0.00010999061 2.0743004f-5 0.00017862904 4.84189f-6 6.600437f-5 -0.00013849605 -0.00015473524 -3.09727f-5 -0.00015960235 -6.02756f-6 -3.5345383f-5 6.6446833f-6 -9.3897535f-5 0.0003320896; -0.000106382155 -3.5760095f-5 -0.00021926186 7.784457f-6 -1.501165f-5 0.00013351171 -0.0002190534 -0.00010459795 -9.931798f-5 -0.00014619823 -4.4387412f-5 2.7898679f-5 -0.00010170522 9.631811f-5 0.00016240298 -0.00012282513 -0.00011049247 8.824968f-5 -8.225247f-5 -0.00011373806 -0.00022956799 -0.0001361572 0.00017823956 2.7578533f-5 -0.00022154422 -9.780573f-5 -4.0669485f-5 8.882254f-6 8.0445185f-5 -0.00012375863 -7.6332486f-5 3.6121033f-5; 0.000106311985 -7.906432f-5 -1.1205585f-5 -1.5131243f-5 -2.301209f-5 9.798071f-6 -3.547471f-5 -0.00026120245 1.2586347f-5 -6.822268f-5 -4.4470264f-5 2.6175876f-5 -7.8987774f-5 -0.00012584275 0.000114576826 7.223032f-5 0.00021333872 -0.000119786804 0.0001256895 -0.00015838687 0.00022052763 -6.422682f-5 -9.331211f-6 0.0001569436 -0.0001671108 9.0518195f-5 2.7931892f-5 -6.773234f-5 6.8878166f-5 7.669723f-5 -7.55923f-5 -6.772959f-5; -0.00011581852 4.056334f-5 -7.584417f-5 6.682919f-5 -0.00014693369 -0.00016408226 -2.7297192f-5 1.0990538f-5 -0.00020758598 -0.0001952491 0.00010930451 -4.570805f-5 -4.3447973f-5 0.00011944701 5.714463f-5 -7.989301f-5 6.1266505f-6 -2.6666114f-5 -6.277608f-5 6.4315704f-5 0.00012569141 -9.230711f-5 3.3064796f-5 4.5073364f-5 -0.00023731914 -9.741791f-5 -3.673756f-5 -0.00012654746 -0.00018654528 
8.457794f-5 -8.238985f-5 2.1891186f-5; 0.00012824725 6.230844f-5 -3.5760561f-7 -5.9993463f-6 -9.454958f-5 1.33715375f-5 1.2036471f-5 -3.8021746f-5 3.6707597f-6 -0.00014395555 -1.9633684f-5 -5.2429145f-6 -3.39631f-5 0.00014288817 3.8284277f-5 0.00020930635 0.000100583595 -8.1506885f-5 -9.2437986f-5 -7.1881045f-6 -7.4705626f-5 -1.2293477f-5 0.00010982288 -6.756013f-5 7.998698f-5 -5.2643347f-5 -5.301267f-5 -0.00010159299 -2.7576549f-5 -0.00015452213 -0.00013258905 0.00014465221; 0.00014644246 -2.7508222f-5 1.7154618f-5 5.141955f-5 -4.77926f-5 -1.9170036f-5 -0.00012884533 -3.879247f-5 7.27141f-5 4.1529744f-5 -7.674755f-5 5.46015f-5 5.1418017f-5 9.102186f-6 0.00011973223 -7.008698f-5 0.00011143945 -0.00014616331 -0.00017614207 -1.9930983f-5 -7.7717246f-5 0.00016771384 0.00010246564 4.0850435f-5 -1.4791896f-5 1.773296f-6 -6.777388f-5 -9.320183f-5 9.793747f-6 3.565856f-5 6.677867f-5 3.1949283f-5; -7.099837f-5 -3.880549f-5 -6.204006f-5 -0.00010500521 1.823292f-5 4.637754f-5 -3.7206995f-5 6.689372f-5 0.0002624164 6.272799f-5 -6.675406f-5 8.9881854f-5 2.916284f-5 -9.199561f-5 -3.960647f-5 -6.997204f-5 7.3515985f-5 -5.2350006f-6 -2.9981193f-5 7.657916f-5 -5.855929f-5 -6.207242f-5 1.6234853f-5 0.00018253861 -0.00018566557 -2.0411622f-5 0.00012313387 -6.8692752f-6 -3.3356002f-5 -4.484656f-5 5.2030642f-5 1.2304378f-5; -5.574587f-5 -7.176991f-5 0.00013603072 -0.000149423 -8.11134f-5 0.000103797676 -2.3516186f-5 4.6030716f-5 -1.7114471f-5 0.00014125546 -0.00018253592 6.564444f-5 -0.000121033656 -0.0002706364 5.073237f-5 -0.00013825523 4.2900978f-5 2.8626737f-5 -5.41647f-5 1.5618776f-5 -3.4429024f-5 -1.3398055f-5 7.332189f-5 0.000100944584 0.0001148984 -0.00013830149 1.4059937f-5 9.2065064f-5 -0.00014090656 0.00021189864 -9.320717f-5 0.00010519558; 0.00020013636 1.5572539f-6 -0.00015192667 -0.000110940586 6.263926f-5 5.9184727f-5 -1.3553948f-5 -0.00020229985 6.055178f-5 -4.4354092f-5 -0.00010978768 5.5385626f-5 1.6523543f-5 -3.5882742f-5 -0.00018012227 8.173728f-5 5.015054f-6 
0.00017521388 9.933107f-5 0.00018757579 1.0158077f-5 7.643827f-5 -0.00012923847 -4.7052465f-5 -2.3069855f-5 4.7519363f-5 -4.298575f-5 -4.784051f-5 7.415187f-5 -0.00010853257 8.4092266f-5 9.127179f-5; 2.5515355f-5 -9.5564355f-5 -9.037943f-5 0.00012791662 -8.415058f-5 0.00014936812 3.0504425f-5 9.485063f-6 -2.9561497f-5 -0.00014902947 3.9537837f-5 -0.0001270693 -8.992534f-5 0.00011683574 0.00011981433 -2.0648791f-5 0.000155162 -3.621236f-5 -6.9292844f-5 6.620756f-5 4.0673425f-5 9.10418f-6 0.00011657249 0.0001450235 0.00013026349 -1.8015302f-5 -0.0002813654 -5.035065f-5 6.893402f-6 4.130788f-5 -6.730433f-5 0.00022247569], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-8.983584f-5 -0.00013433404 -4.2122025f-5 -4.2126765f-5 -5.34522f-5 9.690064f-5 0.00011555653 7.3161355f-5 -2.2754948f-5 -0.0001903935 -4.5107634f-5 -8.0208636f-5 -9.7241624f-5 -4.576825f-5 -0.0001694684 -8.23452f-5 -0.00012939026 -0.00015752144 2.7088996f-5 0.00013749683 -0.000106728556 -9.9816214f-5 4.46386f-5 -0.00017696527 -1.9557518f-5 -1.43644675f-5 -0.00012814153 -0.00011666217 -7.3545394f-5 -0.0002074978 -0.00010953907 6.778721f-5; -0.00012386445 -3.7070597f-5 -7.233126f-5 -0.000102834245 -0.000108851644 0.00018463138 -1.1992287f-5 1.2755492f-6 -6.4162654f-5 6.285948f-5 -3.6387777f-5 9.828097f-6 -7.824389f-5 -0.00011592343 8.771088f-5 4.658816f-5 -0.00018579564 -9.645722f-5 -4.554248f-5 -2.8074031f-5 -9.886233f-5 -8.83781f-6 0.00012870507 -1.9690957f-5 -5.308007f-5 -4.6081364f-6 0.00012029303 9.976725f-5 0.00024238172 2.9957906f-5 0.00018049676 -1.5547093f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most deep learning frameworks, Lux defaults to using `Float32`. However, in this case we need `Float64`:
# Convert the initial parameters with `f64` and wrap them in a ComponentArray, which is
# what is handed to the ODE problem and the optimizer below.
const params = ComponentArray(f64(ps))
# Wrap the network together with its state `st`; the parameter slot is left as `nothing`
# because `ODE_model` supplies the parameters explicitly on every call.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
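Now we define a system of ODEs describing the motion of a point-like particle under Newtonian physics, with each rate multiplied by a neural-network correction. It uses the state $u[1] = \chi$, $u[2] = \phi$ and
$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{NN}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{NN}_2(\chi)\bigr)$$
where $p$, $M$, and $e$ are constants.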
# Right-hand side of the neural ODE: the Newtonian rate scaled by a learned correction.
#
# Arguments:
#   - `u`: state whose first entry is χ and second entry is ϕ.
#   - `nn_params`: parameters passed to the global `nn_model`.
#   - `t`: time; unused, present to match the ODE right-hand-side signature.
#
# Reads `p, M, e` from the global constant `ode_model_params`.
# Returns the 2-element vector `[χ̇, ϕ̇]`.
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)
    # Both rates share the same Newtonian factor and differ only in the network output.
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Integrate the neural ODE with the initial (untrained) parameters, using the same
# fixed-step RK4 settings and save times as the true model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the waveform pair, matching how `waveform` was built.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
# Overlay the reference waveform and the untrained neural-ODE waveform on one axis, each
# drawn as a line plus markers, with the legend placed at the left-bottom corner.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
# Reference data.
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
# Prediction from the untrained network.
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
axislegend(
ax,
[[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"];
position=:lb,
)
# Evaluate `fig` last so that the figure is the value of the block.
fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

# Training objective: mean-squared error between the waveform predicted with network
# parameters `θ` and the reference `waveform`.
#
# Re-solves `prob_nn` with `p=θ` using the same fixed-step RK4 settings as the reference
# solve, and reads the globals `u0`, `tsteps`, `dt`, `dt_data`, `mass_ratio`,
# `ode_model_params` and `waveform`.
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    waveform_pair = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    pred_waveform = first(waveform_pair)
    return mseloss(pred_waveform, waveform)
end
Warmup the loss function
# Evaluate the loss once at the initial parameters before training starts.
loss(params)
0.0006340034930778956
Now let us define a callback function to store the loss over time
const losses = Float64[]

# Optimizer callback: record the current loss `l` in `losses` and print a progress line
# with the iteration counter read from `θ.iter`.
# Always returns `false`.
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end
Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.
# Select the Zygote automatic-differentiation backend for the objective.
adtype = Optimization.AutoZygote()
# The objective is wrapped as a two-argument function; its second argument `p` is ignored.
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
# Start the optimization from the untrained parameters `params`.
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a backtracking line search, at most 1000 iterations; `callback` is passed
# to the solver so the loss history is recorded and printed.
res = Optimization.solve(
optprob,
BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
callback,
maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.0001697314582995438; 0.00010228565952268363; 8.145117135412566e-6; 6.411446520359448e-5; 5.702866837957861e-5; 4.769039878741915e-5; -8.669598173571739e-5; 4.466077007233226e-5; 0.00012299527588749412; -0.000132572822621765; 2.750839121290538e-5; 9.581403719493089e-5; 6.930337258375225e-5; -1.6895068256411906e-5; 9.430357022218693e-5; 1.4941845620335468e-5; 1.629596954443344e-5; -1.7224563634942714e-5; 3.4461067116314995e-5; 7.41924086467686e-5; -0.00016886815137683687; 7.546447886844119e-5; 7.958376227188702e-5; 3.6587927752395104e-5; 0.00011942316632460914; -9.249538379646771e-6; -0.00015180307673277394; -3.257717253290365e-5; 0.0001192980125778297; 5.59488871657958e-5; 4.750215884993394e-5; -0.00012370357580935196;;], bias = [-2.9274388542120315e-16, 7.590827761542468e-17, 1.7592312048958112e-17, 7.970489487144603e-18, 7.704611506641097e-17, -4.553477879690255e-17, -1.330504381276346e-16, 4.3236082434421294e-17, 1.9517883019677685e-16, -3.0270683761579085e-16, -2.4400027054026977e-19, 1.219885007312881e-17, 7.759959809314279e-17, -6.0809451427638236e-18, -9.146086116332341e-17, 7.820296737396714e-18, 1.2620455729153506e-17, -1.3676579586897309e-17, 4.5250652974961455e-17, 1.813840310005523e-16, -1.4318615541780622e-16, 1.0003445740888312e-16, -1.085441371781708e-16, 1.3591050492109348e-17, 2.8327602418001855e-16, -1.9394710040410375e-17, -1.5266350451820516e-16, -8.09420856455454e-18, 1.1591263191948622e-16, 8.601234327872286e-17, 3.7470768113616866e-17, -1.5990522230349131e-16]), layer_3 = (weight = [-0.0001676979064084948 -2.7055235523076147e-5 -0.0001708220279850554 -0.00013386375762806317 -0.00012508594049890606 -2.2360509365029945e-5 4.8035975805903965e-5 -1.8344513637061668e-5 -2.3635688239517724e-5 0.00016299452722720438 -0.00010172550360025098 -7.416317935152723e-5 -3.980085425595764e-5 -8.333283449106197e-6 4.9076073946852116e-5 2.6625328200820528e-5 0.00013942646360206202 
-5.548937956504737e-5 6.5078167563256e-5 -2.515240839641561e-6 -5.7717274148518804e-5 -2.330196371062163e-5 0.00010912325701625143 0.00015208761938097722 -0.00016733336638011551 -3.624912465042658e-5 4.903244366702504e-5 -0.0001356374469195045 0.00010465694681457658 -3.232381463917058e-5 2.7108066160649674e-5 5.4132806281051564e-5; 1.6384228021010965e-5 -0.00011740783045638093 5.4540437505482177e-5 -9.253972055261034e-5 -1.2284128187034873e-5 -6.452249434673501e-6 0.000127092110434441 0.0001856881765172745 -1.648495415679407e-5 2.3319950943230476e-5 -9.79648928284704e-5 8.780434038386141e-5 -6.639579330907284e-5 -1.104717264636961e-5 3.223134010227819e-5 0.0001096256866924411 -5.5522584060538365e-5 -1.6260090686712253e-5 1.3385523944877252e-5 0.0002391589270894307 -9.535555162769698e-6 0.0002191321593466104 -8.315605626079493e-5 -0.0001571767013763143 -3.855264256230693e-5 8.8281781446556e-5 0.00011161546464904088 7.483280680504936e-5 0.0001203282565450096 -1.684480483048503e-5 0.0001354646421725787 0.00017873055660061292; 0.00016121110947192656 0.00012128416837280901 0.00010605160186154678 -2.96198419245746e-5 -0.00011503954611631087 -9.620933870697405e-5 -0.00019094096079148176 3.076545127480437e-6 -8.02669511263275e-6 1.4728970555596302e-5 1.0511459905539458e-5 4.520043275173991e-5 4.558994386665603e-5 7.181252421200194e-5 -0.0002132668729327224 3.8723804565089846e-5 -0.00011552101805951529 4.792149174231023e-5 -5.470915367273663e-5 -0.00016613682838302005 6.794140189508905e-5 0.00016774966333381937 3.817372761754454e-5 0.00020025440734665798 -0.00015533850692823112 -0.00012101705088658346 -5.4406019088627896e-5 6.222878563770972e-5 5.5553363737132856e-5 -0.000194734543228112 -0.00018833127852208926 0.00010113643783465485; -7.903861820417401e-5 -0.00019461556053905908 -2.9832151775591086e-5 3.062505151425255e-5 -6.432218705267979e-6 0.00014819541706052354 8.415084470576414e-5 -0.00011026226253201511 6.129362865213893e-6 0.0001669051346117716 
-5.5281245830344696e-6 -7.407519896186995e-6 -4.5201409883557274e-5 0.00010841836796422213 -5.251158996437854e-5 -7.990403515472355e-5 8.834010826936645e-5 -1.6194127364004185e-6 3.108535769675592e-5 -9.329522414085671e-5 -3.029239975043055e-5 5.894910667953799e-5 -2.5343802698298568e-5 0.00016575051289797495 1.6501135397594205e-5 -4.6059561790426266e-5 -2.0882965835733454e-5 -0.00010526345409638518 -6.352694595883716e-5 -9.468315671150816e-5 1.1019568083756388e-5 6.340286949231188e-5; -9.061359574779396e-5 0.00010970910974783022 -0.00011100117095276372 -3.3608259030299487e-6 1.1239251617268691e-5 -4.2726320461327854e-5 7.142397393244246e-5 0.00020074688094672295 -5.244126274128921e-5 -5.853854068782488e-5 0.00012446373315532148 0.00025077589984056584 -0.00013295957447490914 -2.69008652679214e-5 -0.0001395404162024611 2.8877995501429573e-5 8.221543415919301e-5 -7.810186631042153e-5 -0.00011476468913252978 -9.664517543507022e-5 7.88762716553778e-5 -2.1455631357922354e-5 -0.0001808311555208188 2.56015080964841e-5 -0.00013389517531259526 -2.0123561770847367e-5 7.242222491824792e-6 -1.3618207194044342e-5 -1.344395431537564e-6 -0.00018805593404914628 -3.24398080550178e-5 -0.00011941000971938016; 5.271884649575235e-5 3.213427473840392e-5 -0.00012450080546384275 -2.1593093401838108e-5 -0.00011957004905977914 -9.377022479144459e-7 5.1015777648921974e-5 6.329974414714021e-5 0.00010547638223672238 2.0020831220810025e-5 2.280712280937847e-5 -8.174158144809177e-5 -0.00016300096765460193 0.00012571669933552314 4.111412840779139e-6 3.1517782849757204e-5 -3.2269202415566695e-5 -2.0358428878328583e-5 2.638373255197323e-5 0.00010257692040343075 6.403088329990881e-5 5.401423616151556e-5 -0.00017931140907874124 5.20679720694049e-5 5.303554848999258e-6 -0.0001337431051751745 0.0001259466923557007 0.00011074341884980347 3.121964321055427e-5 2.361040671987766e-5 3.278787037290965e-5 2.0933230848661094e-5; -5.950247233408523e-5 0.0002837198842710325 7.65671354995589e-5 
1.695371955726282e-5 -0.00013300841984595151 -8.086943872761855e-5 2.8507795716563396e-6 0.0002404714175893466 -5.975541736255832e-5 0.00019887705255974491 6.0198988455581697e-5 -5.852916384604966e-5 6.902229487174152e-5 8.792854331410623e-5 9.204525830433508e-5 -2.7586140470461506e-5 -9.237702053063415e-5 3.6550674655941695e-5 0.00016660579402675017 0.00012311544541013358 2.6235473148748187e-5 -2.883313956111743e-5 0.00010822461607609781 -0.00011626828219299112 -0.0001973281616417265 -2.595799943513563e-5 8.404890085080837e-5 -0.00017199835584754956 9.44393449322854e-6 -8.478372202510921e-5 -8.800451925047616e-5 -0.00010033644522331162; 0.00012411556808236034 -9.870271518786685e-7 -0.00011090414928314324 0.00013139717183965213 7.117669001179703e-5 3.117068395828806e-7 -9.591261341976973e-5 2.9992498135276784e-6 -1.8227113928716033e-5 2.5350098707875857e-5 0.0002689923551721028 5.3486490555386584e-5 2.8266067870007384e-5 9.696971237151835e-5 -0.00011567073825555892 -0.00010429220283678781 9.83559350921302e-5 -4.923969031934882e-6 -3.019107289712458e-5 -1.2781618084246728e-5 -9.576754537685794e-5 -0.00012653178574093663 -0.00019449385736645242 -0.00012360967385158925 0.00013195676573975909 -7.11744812228256e-5 -0.00010526785054518866 0.00010099338969173794 0.00014840214707381057 -8.721023341915616e-5 -5.425091972412077e-6 1.108836685905665e-5; 0.00010123228425448964 -4.118019783698266e-5 -4.78446730593415e-5 1.4108706313087383e-5 -7.274259256144838e-5 -1.9908932113944765e-5 0.00016979009673752358 8.742790818734147e-5 0.00010589746820562627 -0.00010506417161729828 -0.00012354745409164618 5.177623865275037e-6 -0.00020537980013239865 6.378916910881122e-7 -4.3093865638984256e-5 -6.61923462356635e-5 -2.3695271334769986e-5 -0.00015082596506171004 -0.00012527884096551697 -3.1433830696080454e-5 6.650140280192898e-5 2.3699206980488763e-5 0.00019240611586075656 -8.820837534211574e-5 -0.00012318209915599743 -9.800103576334286e-5 5.919808758711213e-5 -1.846535665744003e-5 
5.4695557668423664e-5 -6.0265860480180503e-5 -6.075638371465656e-5 -9.641425127108462e-5; -9.730212986790164e-5 9.440717624663838e-6 -4.060237115221258e-5 -4.482095318702129e-5 -0.00013376872738911868 8.512411368689385e-5 4.20102849628775e-5 -9.076049099997738e-5 -1.3369478617397615e-5 -3.668054268667288e-5 -3.309285166063491e-5 8.497502931538059e-5 -1.6513398074615473e-5 6.265038319076257e-5 0.00010527607883238248 5.7474445204994e-5 -8.407243803697792e-5 7.229511359460498e-6 0.00011990765852066552 -0.00022197015568702533 -9.330661777960662e-5 -0.00013434784995946148 0.00011878324658288584 7.870422341772221e-5 -6.697710996116389e-5 0.00011379787684096798 1.2671113657116118e-5 0.00014010148952626483 -5.089140831581197e-5 -4.8552842779039466e-5 4.298851563623059e-5 -7.00470547573161e-5; -0.00010720503341502909 -0.00013930166896530504 7.045184902748072e-5 -3.976643056008646e-5 5.87292554010103e-5 0.00011556381942140103 -6.8480065201209015e-6 0.00010907991527713977 7.881257916237542e-5 -0.00012434305681012985 3.568841024421052e-6 -0.0001335087115839608 3.1333841619961006e-5 -0.00012380393745476387 -3.2570384788375596e-5 5.01236201708111e-5 -2.9731477112332414e-5 -0.00013740925059752544 -4.2933972028062625e-7 -0.00010447585081779895 -2.5110701968459927e-5 -8.786834659618805e-5 -1.7566974711006763e-6 -9.611086424771609e-5 9.540646279133165e-5 -0.00013116129848273495 2.4380626237467002e-6 -2.3950830651269606e-5 -1.82002957407827e-5 -2.1176791775326684e-5 0.00011619489232910962 -0.00015725117991310215; -7.356652782011137e-6 -4.069424746862949e-5 -0.00018162408948716228 -0.00014059700176890984 -0.00016907261149332053 4.918687202429172e-5 -8.071727486399633e-5 -5.637065830829808e-5 4.459524436665831e-5 -0.00013532351865737317 -0.00012732671737045428 -1.7017399182217188e-5 9.331401245130253e-5 -1.985903898036069e-5 3.3766149249137105e-5 0.00018180356139476907 0.0001581406774900066 -5.1187367967362116e-5 -2.5274821632922264e-5 -9.423229337276793e-5 -7.964621752337936e-5 
6.655413901593654e-5 6.495559977923463e-6 -4.275272137907655e-5 -1.3893203936593613e-5 -0.00012997250294153096 -8.17868988409266e-5 8.860343380059698e-5 -4.019135874418498e-5 -0.000134514461274521 7.590156903883213e-5 -0.00010434108687636292; -0.00012957022278481273 0.00016967581308068865 2.395534778275769e-5 8.874292092019617e-5 -1.2920141356868273e-6 0.00017668218191334832 -7.923622485162644e-5 -1.9376344940100472e-5 -0.0001641722906837766 -6.885865846409526e-6 0.0001358830353952366 4.8924584402415127e-5 -0.00016490596914576418 0.00010131964648695446 -0.00018364943258842355 -0.00014845841957634531 -0.0003314704795824945 1.970044052955945e-5 -7.414208681933766e-5 0.0001307424258216654 -1.1824490289135184e-5 0.00016250186439225432 5.239968362235469e-5 4.263382782897735e-5 2.8284788126771472e-5 -9.032932445319593e-5 -6.10674456451586e-5 -6.104459186230043e-5 -0.00011409501036308414 -2.797410758415815e-5 7.091210266895448e-5 6.465392783372061e-5; 4.233770735781268e-5 4.609585588085418e-5 -1.5277032131705628e-5 -1.710903431300207e-6 -2.0596017263157355e-5 0.00011160914576365293 4.0227756047247384e-5 -2.9087393311030647e-6 2.772460678287423e-5 6.371508200871946e-5 -0.0002645847184491846 -5.287665268653162e-5 1.518154070552902e-5 2.8751986549907247e-5 -3.739205427947554e-5 -0.0001063924123324958 -0.00013661065028529068 -5.082139474936214e-5 -0.00013373264525100013 3.558283927603413e-5 -6.201682451053379e-6 -9.249224828337525e-5 -7.999838859788327e-5 0.00011307451635119156 -9.400337009238931e-5 -3.380060522132944e-5 4.2380573662094726e-5 1.7342919393997773e-5 6.304491538074877e-5 -2.7546972402630487e-5 5.440141606712395e-5 0.0001374364306077838; -9.65144257541694e-6 -0.00013605177567704136 3.196113460511221e-6 3.3061619127258874e-5 8.982302470237567e-7 -2.590178458965696e-5 3.356195762855509e-5 -2.1114324532824843e-5 -1.9376289252686604e-5 6.910450125680597e-5 -0.0001257968809445417 -4.179407672971017e-5 7.829639313378933e-5 -2.2851357549804175e-5 0.00010480989536838151 
-0.00013302732016789068 6.329773410791262e-5 -0.00012853213169044153 0.0001261654275945359 9.892272888784843e-5 -2.6425792261608902e-6 7.035586410304144e-5 -6.326711930931016e-5 0.0001558885002518975 0.0001697365393406655 -4.324908635360455e-5 4.976423102881934e-5 -0.00012985314000295758 9.487998855867787e-5 2.3130569992211125e-5 5.558103536124961e-5 -0.00011672601734651845; 9.68674804610438e-5 -0.00010667541107952543 4.3607616244366155e-5 5.1577387348766293e-5 0.00014341929010722656 0.0001289633884155362 -6.42172952936778e-5 -2.6474369296670644e-5 -4.862532194212841e-5 -8.432544462241003e-5 0.00012724082000229313 -6.624274082899594e-6 -4.7104170225542946e-5 -2.443985875263549e-5 -0.0001282660586000233 -0.00022639689894251734 -1.4353154146852622e-5 0.00016366007286730948 3.027008380765845e-5 5.0350529927778265e-5 -2.4752957760681913e-5 -0.00014802313569791685 3.4853322286176906e-5 2.8173198297985918e-5 7.569755063766763e-5 -5.240119105994911e-5 -5.2773461793294215e-5 -0.00013845165711220543 2.2027697123374402e-5 4.1536765930027305e-5 -4.2336850019387596e-5 -4.702098420213735e-5; 4.780213187280861e-5 4.90729324398135e-5 0.00010009932505149166 -0.00017751224747676754 -6.019324080021449e-5 -1.0640239609853082e-5 0.00017997145535140772 3.5244341476244322e-6 -6.708091329834185e-5 3.093187826983795e-5 7.942219583732757e-5 1.0851784124171386e-5 0.0001171893416782431 6.30942266312963e-5 -0.000146789356379812 4.6811964990997194e-5 -4.2137941066680975e-5 6.62482604315693e-7 2.4407536172151713e-5 -1.6656393357477394e-5 2.0508874355650657e-5 -5.847795925852314e-5 -1.6187728918664654e-5 5.075296924282527e-5 0.0001127747489832834 0.00015356700516953837 9.741902961344456e-5 -0.00010747427966030574 -1.951522442446755e-5 -0.0001749508630140158 6.147526785447898e-5 -5.490387427322332e-5; -0.0001284200607435162 4.245400302607816e-5 0.00020506705365989022 -2.180427147976592e-5 0.00011909830388759752 7.375559104986082e-5 -0.0001090628848280071 -0.0002096523876636515 
9.863793272671384e-5 -9.999952274228434e-5 -3.674894653812408e-5 0.00013129545002158092 -8.762677137127639e-5 -2.130849318490239e-5 -3.2920770561021515e-5 -9.393553252435421e-6 8.755007181463424e-5 4.9509975699778956e-5 -0.00014449903940402342 -1.869204063694974e-5 1.863111171842666e-5 0.0001416883260142751 -3.966165425225703e-5 4.392174921378116e-5 -4.889786034908954e-5 0.00016517386572815142 5.4699396034998984e-5 -0.00012118710403882335 9.229289364892832e-6 5.688344569790673e-5 0.00013832371948666462 2.0801398907504925e-5; -0.00011982684224553296 -0.00024035610692286742 -1.611444383222607e-5 -1.5910898689192016e-7 -0.00018649018454077333 0.00015725351261728667 1.692782792027397e-5 -4.8267813661981254e-5 -0.00013262337449183097 -6.205617534638475e-5 -5.260899678137366e-6 -4.755263432222427e-5 -1.4386012816130966e-5 3.044704647556903e-5 0.0001869931162503544 4.8593837738121116e-5 2.68175350641508e-5 1.6501043893480932e-5 3.707626604530144e-5 -0.00024769574371809517 -7.537243652099663e-5 -9.850746152274652e-7 7.876166058061348e-5 -0.00013836106369915287 -2.8954469934124113e-5 2.4833530579893895e-5 2.329311576554545e-5 0.00014605204490301298 -7.029760160427038e-5 -9.098082079496855e-5 -3.4686009138270835e-5 -0.0001951067518972731; -2.0245865575379072e-5 -3.9818599011173445e-5 -3.642454668310547e-5 -6.976813584928156e-5 2.7557925373420012e-5 -4.205847345887585e-5 5.08157203380163e-5 1.8999107292490846e-5 -2.5514875791969697e-5 -1.9110490588472134e-5 4.850790663609259e-5 -9.211514099164218e-5 2.4193593513209574e-5 -0.00013067988458219043 -2.223618740985012e-5 -6.386972073827934e-5 -0.00014642681764102602 -0.00012086462503661459 -2.5005926194847432e-5 7.16643074933568e-5 -6.775597389168415e-6 0.00012746733717491289 -0.00019612799643667925 -7.967021212111202e-5 -6.54264409736532e-5 2.0915630907853313e-5 8.368747092888573e-5 -3.5296835098545974e-5 9.301607804942993e-5 -7.119902567368766e-5 4.5896419929211076e-5 -4.817771752144009e-5; 3.3284460070926866e-5 
-3.6554641272005315e-5 7.83303300728294e-6 -5.929856445413273e-5 -1.1827936875981871e-5 -9.351468645740311e-5 0.00016619741188356134 0.0001641870356909373 4.7871245326172994e-5 5.7044263369100086e-5 7.508293026354299e-5 -1.610879656050105e-5 4.282393171911802e-5 7.940557668212984e-5 -0.00013913336181553735 -7.189673954393526e-5 2.833525729534866e-5 8.798407622840437e-5 -2.964799821475387e-6 -7.539149489684968e-7 0.00014804379577274663 -2.8076984679269312e-5 -5.6852428567671245e-6 0.00010994595038233406 -0.00012062841153943225 4.676386640518174e-5 7.275911309283059e-5 0.00011945240373484899 0.00011646467725972403 6.713791196694507e-5 -0.00012930608966393409 0.00018892968625754703; -4.519633651818537e-5 7.621969276932115e-5 -4.8414910938494694e-5 -4.632991980053287e-5 2.289064692866741e-5 -0.00012572515109660797 -6.577690649884637e-5 1.6554310118758738e-5 0.0002021155437117902 -0.0001583293283462071 -2.3503530914799332e-5 -4.301680522495803e-5 -5.2865992560887745e-5 4.445810051528095e-5 -3.9563329070157774e-6 -0.00015814242354700986 0.0001135032639152865 -3.633572626732932e-5 3.844660247103579e-5 0.00013210751646091393 3.100112959635062e-5 0.0001604290359549706 -0.00011799737834245753 6.0945460486787224e-5 0.00015125713840951165 -9.378758253170816e-5 7.317689457102428e-5 5.241314508146878e-5 -5.8151426432794376e-5 9.120786714027856e-5 -7.1586100917692785e-6 4.869295522929101e-5; -0.00013288030020896095 -4.336564887319958e-5 -0.00010324091561610035 -7.33898297604716e-5 3.909933018676027e-6 -7.657515669247118e-5 -2.603804676449426e-5 -5.02558776328577e-6 -4.894053125292893e-5 -8.537747924686116e-5 0.00013526948607055366 6.430507784704309e-5 -8.090650515636212e-5 -0.00012419181728741296 8.601902374527025e-5 -2.857165525905932e-5 -0.00011522141283616564 1.9600642934035064e-5 -0.00010999300428206167 2.0740607231220162e-5 0.00017862664100710529 4.839493447308748e-6 6.600197668182071e-5 -0.0001384984454326896 -0.00015473764067986014 -3.0975096871900984e-5 
-0.00015960474700670894 -6.029956859622229e-6 -3.534778005098894e-5 6.642286569203408e-6 -9.389993209164573e-5 0.00033208720998834004; -0.00010638706404276643 -3.576500417694156e-5 -0.00021926676523437256 7.779547595404621e-6 -1.5016559664233983e-5 0.00013350680277652 -0.0002190583090487141 -0.00010460286099261753 -9.932288772688127e-5 -0.00014620314137572837 -4.439232176084423e-5 2.7893769620271327e-5 -0.00010171012942011162 9.631319884645323e-5 0.0001623980734075237 -0.00012283004431085163 -0.00011049737619034447 8.824477303276343e-5 -8.225737652539556e-5 -0.00011374296987925237 -0.00022957289726034303 -0.00013616211614136134 0.00017823464626416183 2.75736238472557e-5 -0.0002215491312745122 -9.781063814445858e-5 -4.0674394731451664e-5 8.877344989722363e-6 8.044027530559314e-5 -0.00012376353512083145 -7.633739547602008e-5 3.611612377426737e-5; 0.00010631150765313701 -7.906479474421011e-5 -1.1206062635195462e-5 -1.5131720982920231e-5 -2.3012567152348332e-5 9.79759336679927e-6 -3.547518751382294e-5 -0.0002612029322053633 1.2585869472804024e-5 -6.822315951374599e-5 -4.447074131256473e-5 2.6175398335015413e-5 -7.898825167010977e-5 -0.0001258432286422099 0.00011457634794261863 7.222984534657048e-5 0.00021333823936494487 -0.00011978728177289989 0.00012568902714612395 -0.00015838735013765714 0.00022052714742223662 -6.42272981794873e-5 -9.33168864228946e-6 0.0001569431213565773 -0.00016711128152472257 9.051771742726163e-5 2.7931414325412684e-5 -6.773281817821352e-5 6.887768798243892e-5 7.669675217677373e-5 -7.55927805302992e-5 -6.773006786623314e-5; -0.0001158225299768796 4.055932871047841e-5 -7.584817873799246e-5 6.682517917574912e-5 -0.00014693769705268142 -0.00016408627058631456 -2.7301203288556553e-5 1.0986526405711353e-5 -0.00020758999241300743 -0.00019525311151303846 0.00010930049532539297 -4.571206088079807e-5 -4.3451984736792144e-5 0.0001194429983406237 5.714061765680316e-5 -7.989702359559928e-5 6.122639235350279e-6 -2.6670124923889274e-5 -6.278009320163026e-5 
6.431169234197366e-5 0.00012568739978813584 -9.231112423371719e-5 3.306078500095644e-5 4.506935286549082e-5 -0.00023732314954762345 -9.742192420895213e-5 -3.6741572756019925e-5 -0.00012655146789989307 -0.00018654929087282162 8.457393190105846e-5 -8.239386303835265e-5 2.1887175109411128e-5; 0.00012824677826507632 6.230796855312166e-5 -3.580786178583016e-7 -5.999819288938364e-6 -9.455005062840665e-5 1.3371064530708483e-5 1.2035998158588089e-5 -3.8022219153156434e-5 3.670286721225638e-6 -0.00014395602199578216 -1.9634157136118162e-5 -5.2433875432521624e-6 -3.396357358064264e-5 0.0001428876930069754 3.82838042489842e-5 0.00020930587918633763 0.00010058312175465914 -8.150735818411989e-5 -9.243845858940674e-5 -7.188577522121906e-6 -7.470609859659656e-5 -1.2293950087716612e-5 0.00010982240921953954 -6.756060617839734e-5 7.998650828745602e-5 -5.264381986557342e-5 -5.3013143836086964e-5 -0.0001015934613612835 -2.757702174925791e-5 -0.0001545226042005585 -0.00013258952101090726 0.0001446517343786213; 0.00014644285057108442 -2.750782822879661e-5 1.7155011853514888e-5 5.1419942799451875e-5 -4.779220670791414e-5 -1.9169642648581058e-5 -0.0001288449411662681 -3.879207633183391e-5 7.27144952248387e-5 4.1530137430138176e-5 -7.67471565199461e-5 5.460189368711637e-5 5.1418411210373994e-5 9.102579504155707e-6 0.00011973262099452173 -7.008658484885501e-5 0.00011143984102779966 -0.0001461629135450724 -0.00017614167790491178 -1.993058958071383e-5 -7.771685249502238e-5 0.00016771422907147453 0.00010246603142838815 4.085082856142849e-5 -1.4791502501272871e-5 1.773689866763311e-6 -6.777348516373292e-5 -9.320143635015977e-5 9.794141085487985e-6 3.5658952590505835e-5 6.677906731745171e-5 3.194967667475115e-5; -7.099811836770667e-5 -3.8805236450124154e-5 -6.203980655801171e-5 -0.00010500495472616239 1.8233174436765776e-5 4.637779357334893e-5 -3.720674130409707e-5 6.689397496710827e-5 0.00026241665846509106 6.272824185042806e-5 -6.675380490879222e-5 8.98821083906037e-5 2.9163094317929743e-5 
-9.199535333710413e-5 -3.9606217520039016e-5 -6.997178633856277e-5 7.351623872136076e-5 -5.23474654574569e-6 -2.998093868473708e-5 7.657941687693155e-5 -5.8559035729113264e-5 -6.207216274152245e-5 1.6235107345429972e-5 0.0001825388674701501 -0.00018566531861825654 -2.0411368218908694e-5 0.00012313412235424973 -6.869021213340336e-6 -3.335574792391427e-5 -4.484630762409133e-5 5.2030896135413225e-5 1.2304632313763221e-5; -5.574660665236069e-5 -7.17706445569369e-5 0.00013602998458433536 -0.0001494237314758103 -8.111413828676847e-5 0.00010379693938072468 -2.351692294117239e-5 4.602997958196948e-5 -1.7115208014119314e-5 0.0001412547187798477 -0.0001825366582337056 6.564370212497056e-5 -0.00012103439278083164 -0.0002706371398956736 5.073163427065915e-5 -0.00013825596599720773 4.290024096614328e-5 2.8625999927195916e-5 -5.416543556141998e-5 1.5618039266581264e-5 -3.442976069284266e-5 -1.339879165546491e-5 7.332115435626152e-5 0.00010094384722530514 0.00011489766418587405 -0.00013830222653571876 1.4059199908545418e-5 9.206432676061654e-5 -0.0001409072958479823 0.00021189790802661045 -9.320790761497175e-5 0.00010519484001331045; 0.00020013679282192796 1.5576857581517823e-6 -0.0001519262342721862 -0.00011094015385105198 6.263968939812013e-5 5.9185159310354944e-5 -1.3553515733738619e-5 -0.00020229942051309052 6.055221352063039e-5 -4.435366022016038e-5 -0.00010978724926325368 5.538605805987497e-5 1.6523975185473183e-5 -3.5882310063688625e-5 -0.00018012183604377316 8.173771025532056e-5 5.015485706745102e-6 0.00017521430968691726 9.933149963536417e-5 0.00018757621988107508 1.0158509079252423e-5 7.643870308448613e-5 -0.00012923803764942058 -4.705203277489054e-5 -2.3069423283679167e-5 4.751979472849653e-5 -4.298531816538619e-5 -4.784007719216748e-5 7.415229888784631e-5 -0.00010853213929884976 8.409269762843018e-5 9.127221970039337e-5; 2.5516464021286785e-5 -9.556324685841893e-5 -9.037831942701429e-5 0.00012791772577760876 -8.414947397691961e-5 0.00014936922788469324 
3.0505533587657972e-5 9.486171521660677e-6 -2.9560388710580608e-5 -0.0001490283585634304 3.9538946012857155e-5 -0.000127068194259536 -8.992423418827244e-5 0.0001168368497848352 0.00011981543765718079 -2.0647682603208374e-5 0.0001551631020366987 -3.621125243521953e-5 -6.929173569811266e-5 6.620867033338259e-5 4.067453382152315e-5 9.105288782977357e-6 0.00011657359836239424 0.00014502460807411236 0.00013026460045797517 -1.8014193400500314e-5 -0.0002813643051093686 -5.0349540387156736e-5 6.894510435030118e-6 4.1308990049531546e-5 -6.730322375811432e-5 0.00022247679852729397], bias = [-1.3733823340719919e-9, 3.162514320116034e-9, -1.0901284972059797e-9, -3.5630259012427624e-10, -2.115629572780321e-9, 8.992401736699387e-10, 1.1639597374330533e-9, -3.06549021293436e-11, -2.2904313365509043e-9, -5.958983225092548e-10, -2.5582230280102997e-9, -3.047057692795361e-9, -1.1046408325676143e-9, -1.009963596452525e-9, 6.427067305530308e-10, -7.629566660824221e-10, 9.348028875423922e-10, 1.0131523724990055e-9, -3.083522347777203e-9, -2.434366746692325e-9, 2.956651097403858e-9, 7.189240801820333e-10, -2.3966841974604175e-9, -4.909416041429036e-9, -4.776384081075901e-10, -4.011297328537408e-9, -4.730034044523222e-10, 3.9381809505091595e-10, 2.54025956011018e-10, -7.368938288333194e-10, 4.3187675201878646e-10, 1.1085750377410346e-9]), layer_4 = (weight = [-0.0007267745307670983 -0.0007712724595547749 -0.0006790607371663517 -0.0006790655111788249 -0.0006903908086572544 -0.0005400380816642183 -0.0005213821761073976 -0.0005637773956039252 -0.0006596935327217462 -0.0008273322423735125 -0.0006820461748719759 -0.0007171470830603636 -0.0007341803342985298 -0.0006827069687803133 -0.0008064071393200391 -0.000719283930213656 -0.0007663289824445784 -0.0007944601571094703 -0.000609849459363555 -0.0004994417468032814 -0.0007436670220082809 -0.0007367549472505669 -0.0005922999684063296 -0.0008139032084211448 -0.0006564962605432561 -0.0006513027071745497 -0.0007650802741807144 
-0.0007536009181834928 -0.0007104841420930354 -0.0008444365279160297 -0.0007464778127713534 -0.00056915150395021; 0.00011277221949219937 0.00019956596981520085 0.00016430541284245484 0.00013380244277758787 0.00012778499282884508 0.0004212680623982405 0.00022464438706211012 0.00023791223830460493 0.0001724739740009038 0.0002994961635632207 0.00020024883461924206 0.0002464646739662078 0.00015839278710159843 0.0001207132465150962 0.0003243475671543708 0.00028322484240945274 5.084104103489049e-5 0.0001401794592619962 0.00019109409843261473 0.00020856259239087976 0.000137774255318966 0.00022779887275292844 0.00036534169101168644 0.00021694543044775946 0.00018355661483034384 0.00023202836378451 0.00035692971524230145 0.0003364039340986557 0.00047901840881097644 0.0002665945882296231 0.00041713344827879357 0.00022108958178661747], bias = [-0.0006369387501540383, 0.00023663668916275781]))
Visualizing the Results
Let us now plot the loss over time
# Loss history stored in `losses`: a line through the values with a marker at every
# recorded entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; markersize=12, strokewidth=2, marker=:circle)
    fig
end
Finally let us visualize the results
# Rebuild the problem with the optimized parameters `res.u`; this overwrites the
# earlier `prob_nn` and `soln_nn`, while `waveform_nn` is left untouched.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
# Same fixed-step RK4 settings as before, now with the trained parameters.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
# Only the first element of `compute_waveform`'s result is kept as the trained waveform.
waveform_nn_trained = first(
compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
# Compare three waveforms on one axis: the reference data, `waveform_nn`, and
# `waveform_nn_trained`.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
# Reference data: a line plus markers at the sample times `tsteps`.
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
)
# `waveform_nn`, labelled "Untrained" in the legend below.
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
)
# Waveform computed from the solution that uses the optimized parameters.
l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
s3 = scatter!(
ax,
tsteps,
waveform_nn_trained;
marker=:circle,
alpha=0.5,
strokewidth=2,
markersize=12,
)
# Each legend entry groups a line with its matching scatter plot.
axislegend(
ax,
[[l1, s1], [l2, s2], [l3, s3]],
["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
position=:lb,
)
# The figure is the last expression of the block, so it is the block's value.
fig
end
Appendix
using InteractiveUtils

# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# GPU back-end version info is printed only when `MLDataDevices` is defined, the
# back-end package itself is defined, and the matching device type is reported
# as functional.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.