Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
Printf, Random, SciMLSensitivity
using CairoMakieDefine some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't very relevant.
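We need a very crude 2-body path. Assume the 1-body motion is the Newtonian relative position vector $r = r_1 - r_2$ of the two bodies, and recover the individual positions about the centre of mass as $r_1 = \frac{m_2}{m_1 + m_2}\, r$ and $r_2 = -\frac{m_1}{m_1 + m_2}\, r$.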
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of two bodies with
masses `m₁` and `m₂`: body 1 follows `m₂ / (m₁ + m₂) * path` and body 2 follows
`-m₁ / (m₁ + m₂) * path`, so their difference is `path` and the mass-weighted sum is zero.

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    frac₁ = m₂ / total_mass
    frac₂ = -m₁ / total_mass
    return frac₁ .* path, frac₂ .* path
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution into Cartesian orbit coordinates.

`soln` holds one column per time sample and has either
- 2 rows `(χ, ϕ)`: the constants `(p, M, e)` must then be passed as `model_params`
  (`M` is not used here), or
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` are read per sample and `model_params` is ignored.

The radius is `r = p / (1 + e cos χ)`. Returns the 2×N matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Check for `nothing` (the default) first: `length(nothing)` would otherwise raise a
        # MethodError instead of this assertion.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled series `v` (sample spacing `dt`)
with second-order finite differences. Interior points use central differences, and the two
endpoints use one-sided three-point stencils.

Throws `ArgumentError` if `v` has fewer than 3 samples, because the endpoint stencils need
three points.
"""
function d_dt(v::AbstractVector, dt)
    length(v) >= 3 ||
        throw(ArgumentError("d_dt requires at least 3 samples, got $(length(v))"))
    # Integer stencil weights halved at the end keep the element type of `v`. The previous
    # Float64 literals (3/2, 1/2) silently promoted Float32 input to Float64.
    a = (-3 * v[1] + 4 * v[2] - v[3]) / 2
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = (3 * v[end] - 4 * v[end - 1] + v[end - 2]) / 2
    return [a; b; c] / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled series `v` (sample spacing
`dt`). Interior points use the central stencil `(v[i-1] - 2v[i] + v[i+1]) / dt²`; the first
and last samples use the second-order one-sided four-point stencil `(2, -5, 4, -1) / dt²`.
Needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    right_edge = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    return vcat(left_edge, interior, right_edge) ./ dt^2
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return the `component = (i, j)` entry of the trace-free quadrupole moment
`mass * (xᵢ xⱼ - δᵢⱼ (x² + y²) / 3)`, evaluated at every column of `orbit`.

`orbit` is a matrix whose first two rows are the `x` and `y` coordinates (one column per
sample). Because the orbit lies in the `z = 0` plane, index `3` is also accepted:
- `(3, 3)` gives `-mass * (x² + y²) / 3`;
- the mixed entries `(1, 3)` and `(2, 3)`, in either order, are zero.

The tensor is symmetric, so `(i, j)` and `(j, i)` give the same result. Any other index
pair throws `ArgumentError`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i, j = component[1], component[2]
    trace = x .^ 2 .+ y .^ 2
    if i == j == 1
        tmp = x .^ 2 .- trace ./ 3
    elseif i == j == 2
        tmp = y .^ 2 .- trace ./ 3
    elseif i == j == 3
        tmp = .-trace ./ 3  # z = 0, so Izz = 0
    elseif (i, j) == (1, 2) || (i, j) == (2, 1)
        tmp = x .* y
    elseif (i, j) in ((1, 3), (3, 1), (2, 3), (3, 2))
        tmp = zero(trace)  # xz and yz vanish in the orbital plane
    else
        # Previously any unrecognised pair silently returned the xy entry.
        throw(ArgumentError("unsupported quadrupole component ($i, $j); indices must be in 1:3"))
    end
    return mass .* tmp
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (finite differences with spacing `dt`) of the
`component` entry of the trace-free quadrupole tensor of `orbit` with the given `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole strain components (the xx, xy and yy
entries) for a body of mass `mass` on `orbit`, sampled with spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    hxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    hyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    return hxx, hxy, hyy
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` for a single unit-mass body on `orbit`, sampled
with spacing `dt`:
- `h₊ = √(π/5) (h11 - h22)`;
- `hₓ = -√(π/5) · 2 h12`.

Numeric constants are converted to the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    amplitude = sqrt(T(π) / 5)
    plus = amplitude * (h11 - h22)
    cross = -amplitude * (T(2) * h12)
    return plus, cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return `(h11, h12, h22)` for a two-body system: the component-wise sum of
`h_22_quadrupole` for body 1 (`orbit1`, `mass1`) and body 2 (`orbit2`, `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return map(+, body1, body2)
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` of a binary, sampled with spacing `dt`. Body 1 has
mass `mass1` and follows `orbit1`; body 2 has mass `mass2` and follows `orbit2`. The masses
must sum to one.

With `hij` the summed components from `h_22_quadrupole_two_body`:
- `h₊ = √(π/5) (h11 - h22)`;
- `hₓ = -√(π/5) · 2 h12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2
    # The unit-sum tolerance must not be tighter than the masses' own precision. Float32
    # masses built as q/(1+q) and 1/(1+q) can miss 1 by about an ulp (~1e-7), which a fixed
    # 1e-12 would reject. For Float64, 4eps ≈ 9e-16, so the tolerance stays at 1e-12.
    mass_tol = max(1e-12, 4 * eps(float(typeof(mass1 + mass2))))
    @assert abs(mass1 + mass2 - 1.0)<mass_tol "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Convert an ODE solution into the (2,2)-mode strain `(h₊, hₓ)`, sampled with spacing `dt`.

`soln` and `model_params` are passed to `soln2orbit`. How `mass_ratio` (in `[0, 1]`) is used:
- `0` treats the orbit as a single test particle;
- a positive value splits the relative orbit into two bodies with masses
  `mass_ratio / (1 + mass_ratio)` and `1 / (1 + mass_ratio)`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio≤1 "mass_ratio must be <= 1"
    @assert mass_ratio≥0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # Test-particle limit: there is no second body to split off.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
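`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using the state $u = (\chi, \phi)$:
$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2}\sqrt{(p - 2)^2 - 4e^2}},$$
where $p$, $M$, and $e$ are constant model parameters.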
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = [χ, ϕ]` with constant
parameters `(p, M, e)`. Returns `[χ̇, ϕ̇]`. The system is autonomous: `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)
    # Factor shared by both rates.
    common = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)
    rate_χ = common * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root)
    rate_ϕ = common / (M * (p^(3 / 2)) * root)
    return [rate_χ, rate_ϕ]
end
# Ground-truth setup. A mass ratio of 0 selects the single-body (test-particle) branch of
# `compute_waveform`.
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
datasize = 250
# NOTE(review): `tspan` is Float32 while `u0` and `ode_model_params` are Float64, so the
# elements of `tsteps` (and hence `dt_data`) are presumably Float32 — confirm this is intended.
tspan = (0.0f0, 6.0f4) # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
# Spacing between saved samples; passed to `compute_waveform` for the finite differences.
dt_data = tsteps[2] - tsteps[1]
# Fixed step size of the RK4 integrator (`adaptive=false` below).
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
# Saved states as a matrix: rows are the state components (χ, ϕ), columns follow `tsteps`.
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first output of `compute_waveform` (the plus polarization) as the target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # Overlay a line and markers for the same series; the legend groups both under one entry.
    l = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s = scatter!(ax, tsteps, waveform; alpha=0.5, markersize=12, marker=:circle)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes 1 input (the first state component $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but if you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and initialize the weights with `truncated_normal` from WeightInitializers.jl, using a small standard deviation (`std=1e-4`).
# Network mapping the scalar χ (see `ODE_model`) to two outputs. The first layer applies
# `cos` elementwise to the input before the dense layers. All dense weights are drawn from a
# truncated normal with std 1e-4; biases use Lux's default initializer (no `init_bias` given).
const nn = Chain(Base.Fix1(broadcast, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
ps, st = Lux.setup(Xoshiro(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.00013486664; -6.086827f-5; -2.1815624f-5; -0.000101836085; 9.51832f-6; 9.755264f-5; -8.0275844f-5; 0.00016536054; -0.00012283979; -3.113165f-5; -1.5506976f-5; 0.00014282213; -4.820003f-5; -9.392016f-6; -2.253255f-5; 5.5018765f-5; -5.4944874f-5; 2.9550167f-5; -0.00018670462; -9.627539f-5; -0.00012719716; 0.0001672438; 0.000109665416; -0.00014522955; 7.6039825f-5; -6.387787f-5; -0.000118554184; 8.658944f-5; 0.0001596787; -9.286083f-5; 4.985584f-5; 9.984686f-5;;], bias = Float32[0.59618795, -0.6482552, -0.23217428, -0.45540428, -0.89014816, -0.19670737, 0.7987876, 0.42849612, -0.4005661, 0.15956736, -0.010939121, -0.034234405, -0.13895881, 0.7470608, -0.33783185, -0.9743047, -0.12773585, -0.5957316, -0.9185462, 0.32255816, 0.38599932, 0.46897388, -0.08547068, -0.62477446, 0.31575274, 0.8093822, 0.57232594, 0.91107714, 0.32292485, -0.510834, 0.34849048, -0.2196194]), layer_3 = (weight = Float32[-8.9265915f-5 0.00012714938 1.3552497f-6 -9.066618f-5 0.00010753438 -4.8836344f-5 -5.0133774f-5 -0.00010982102 -1.392472f-5 -0.00014842844 -4.400101f-5 0.00010470117 -6.111923f-5 -2.318335f-5 3.6405752f-5 -3.858367f-5 4.759052f-5 -9.343398f-5 0.00012573389 -0.0002273207 4.7403668f-5 2.401994f-5 -5.7272515f-5 -5.6557255f-6 3.453421f-5 -7.92658f-5 2.5768559f-5 -6.5919025f-6 6.562397f-5 3.6866848f-5 4.56753f-6 6.3572596f-5; -0.00026701193 -0.00014138964 -0.00013708665 -0.00011134131 2.839011f-5 -1.4562703f-5 -0.0001298136 -2.7648442f-5 -8.6266555f-5 0.00013751947 -1.9721198f-5 5.3618305f-6 -0.00025962957 8.5864194f-5 0.000107604785 2.637031f-7 5.8903996f-5 -3.43967f-5 0.00010941115 4.674156f-6 2.262793f-6 -0.00019473747 -3.777692f-5 -0.00011015411 -3.6522808f-5 -0.00025190655 2.4693616f-5 -3.318374f-5 0.000193765 -0.00012616688 4.1610943f-5 -7.0785354f-5; -7.617572f-5 -1.5221441f-5 -3.6111076f-5 0.00012157453 -6.654173f-5 1.1459002f-5 4.6367622f-5 -3.4537286f-6 -0.00010836191 6.0174025f-5 
1.9261093f-5 -0.00014976203 5.1565505f-5 9.130991f-6 0.00015441085 1.8040638f-6 2.311486f-5 -4.9094164f-5 -5.1385043f-5 -2.7415055f-5 -5.124019f-5 -9.126013f-5 -0.00022966247 -0.00015601667 1.157076f-6 -8.713528f-5 9.481494f-5 -0.00014354732 -0.00012639361 -0.00018863505 -1.4192917f-6 8.0514605f-5; -6.4259344f-5 -3.6979323f-5 0.00018758276 -7.3173064f-6 1.3022077f-5 -1.8890105f-5 -3.1442545f-5 -7.212687f-5 -7.352473f-5 -9.973172f-5 -9.680505f-5 -2.0598107f-5 -0.000103381004 -4.4680324f-5 0.0001149187 3.6679598f-5 0.00014085267 2.9614412f-5 -5.5192006f-5 -8.906838f-5 -0.00010248077 -3.5520272f-5 -5.0916915f-5 -6.7159555f-5 -8.103359f-5 0.00014674455 -9.476528f-5 0.00021752363 0.00018389341 -7.4100906f-5 0.00012255642 7.94781f-5; -6.738597f-5 -8.931153f-5 -0.00013046402 5.0647093f-5 9.5994f-5 3.0536088f-5 -2.8045442f-5 6.215293f-5 -2.640658f-5 -5.1844796f-5 2.4426165f-5 0.00017085944 0.0001904481 9.1379734f-5 -5.9005397f-5 0.0001844943 4.0156963f-5 2.5515941f-5 -2.2844732f-5 2.5007734f-5 -1.895668f-5 6.031763f-5 5.3380103f-5 0.00012101602 8.964602f-5 0.00021779032 -4.2463565f-5 -0.00017798686 5.7310262f-5 9.0245834f-5 5.6403373f-5 0.00012905433; -1.026378f-5 -7.6730015f-5 -3.9467202f-5 0.00019675812 0.00015969122 3.0435562f-5 -0.00021834402 -5.553643f-5 7.705604f-6 -4.745324f-5 -0.00017646217 -0.00013773568 2.3522978f-5 7.1317656f-5 4.2356118f-5 8.168245f-5 6.8279136f-5 -0.00011673913 1.3830483f-5 -4.991736f-5 -0.00015320856 -5.6589903f-5 -6.2069266f-5 4.881788f-5 1.0390524f-6 5.6706153f-6 -4.488635f-5 -5.891834f-5 7.3154944f-5 0.00020035537 -3.9386814f-5 -7.20184f-5; 0.00010926268 -1.7412733f-5 -2.7980333f-5 -9.99285f-5 0.0002475101 7.910591f-5 5.4693304f-5 6.685949f-8 0.0001041361 2.5406202f-6 -6.236749f-6 -1.5909134f-5 2.346214f-5 1.2387029f-5 -0.00026757305 -1.3177067f-5 9.0888854f-5 -0.00015278098 7.030128f-5 -2.4672192f-5 9.935076f-5 4.875598f-5 -8.3597915f-5 -0.000267078 -0.00010592828 -0.0001542098 1.7652557f-5 -3.1884243f-5 0.00011062164 -5.8849007f-5 
-0.00012774026 2.2638607f-5; -1.8954292f-5 3.2556785f-5 3.5873214f-5 -5.7337664f-5 0.00010665116 -0.00012800243 -0.00015707419 -6.199926f-5 5.5246233f-5 2.2884546f-5 0.00010430311 -6.045283f-6 -6.281782f-5 -7.578203f-5 -0.00010418124 -0.000138678 -2.3946366f-5 -6.7014495f-5 5.8415517f-5 -2.6106134f-5 1.7852362f-5 -2.2024664f-5 7.992082f-5 -0.00010937178 2.762842f-5 -5.299589f-6 -0.00010500459 -0.00010136794 -4.995678f-5 -5.522587f-5 -1.1050815f-5 -7.702721f-5; -3.8693393f-5 6.0481143f-5 -0.00010209737 0.00018429513 -2.235241f-5 3.4792538f-5 -6.912488f-5 4.7011443f-5 -1.16442425f-5 3.805059f-5 -2.365294f-5 -7.3903975f-5 -7.32258f-6 -4.0900788f-5 0.0001601519 1.0778848f-5 -0.00013308409 2.8580494f-5 0.00013405178 7.492811f-5 -0.00016382152 -4.2719454f-5 4.369426f-5 7.496435f-5 3.119203f-5 0.00015855905 1.2739034f-5 7.506147f-5 6.0280996f-5 5.8465303f-5 -1.9294015f-5 9.19916f-5; -0.0001077281 -6.410031f-5 -0.00015638382 -0.00017781068 -0.00014204436 0.00019516783 -9.9750876f-5 -1.9362285f-5 4.2778094f-5 0.00012933585 6.3214844f-5 -9.380368f-5 3.361182f-5 0.0001993805 1.9648305f-5 -5.457907f-5 -4.7172212f-5 -9.403144f-5 -0.00013725089 1.535667f-5 -7.965368f-5 9.126882f-6 0.00011936906 -9.018309f-6 -4.166499f-6 2.1959688f-5 -0.00020483673 2.9876988f-5 1.1521733f-5 4.047942f-5 -0.000114734896 1.9828014f-5; -7.337908f-5 4.8569767f-5 -0.00010016533 5.381779f-5 7.80071f-5 2.0464331f-5 5.693881f-5 0.00013255689 6.6924746f-5 0.00019174152 6.5294165f-7 -0.00018148121 -1.38955575f-5 -9.108119f-5 1.3291072f-5 -8.9257795f-5 -3.19133f-5 0.00019684753 0.00013937353 7.760495f-6 9.5373965f-5 -8.8629524f-5 -2.5396184f-5 -0.000174673 -0.000105252926 5.822122f-5 -2.7217613f-5 -0.00012075265 -1.0472349f-5 6.4282845f-5 -0.00014192247 3.134949f-5; 1.3229305f-5 6.5637534f-5 0.00024242204 -0.00018743203 -2.5201256f-5 -2.0596031f-5 4.5691454f-6 -0.00015471819 9.393275f-5 0.00013544998 -7.967779f-5 -3.0933923f-5 2.719893f-5 6.8929835f-6 0.000103082566 5.8968163f-7 6.2360064f-5 -9.6603326f-5 
-0.00020615995 -8.778267f-5 -0.00024846673 2.4875786f-5 -1.548694f-5 0.00014435446 7.499008f-5 -2.4984192f-5 -6.1246334f-5 -0.00014545007 -6.793275f-5 0.0001824709 -0.00011101222 -2.7453976f-5; -2.2740616f-5 -7.696486f-5 1.3907084f-5 7.9247046f-5 2.4819094f-5 9.735042f-5 1.8477944f-5 -0.00026380926 -8.678693f-5 5.9581846f-5 -3.392635f-5 9.227155f-5 2.8709695f-5 -0.00012130848 8.174161f-5 0.00014467505 -0.00015725286 2.6391235f-5 0.00016082506 -0.0001525441 -6.8175787f-6 -6.782615f-5 -2.944449f-5 -6.72308f-5 4.4934317f-5 -9.869582f-5 -1.1900483f-5 9.956004f-5 0.00017628977 0.00012351894 1.0605775f-5 -0.00014333123; -1.9166857f-5 -1.5859869f-5 -0.0001422217 0.00018814215 -0.00014621306 -9.1370675f-5 -2.857298f-6 6.123725f-5 -4.9520473f-5 0.00010888888 1.1925431f-5 -0.0001404149 -1.5226096f-5 0.000104234525 9.008095f-5 1.5647378f-5 0.00016053452 5.0596565f-5 -4.722956f-5 8.744778f-5 0.00023447916 6.569003f-5 -8.444372f-5 8.2286915f-6 0.000104515384 -4.815788f-6 -1.6078395f-5 0.00013036687 0.00010218032 -0.00020275667 -6.469113f-5 0.00012088423; 0.00011470342 9.73572f-5 0.00024630656 0.000110132896 0.0001197487 -5.581413f-6 1.2791754f-5 -3.4887013f-5 -4.499801f-5 -6.321256f-5 -6.1914136f-5 0.00015609441 -0.00013149987 -9.848024f-5 -5.6994304f-5 4.820558f-5 -4.026606f-5 -0.00024225845 -5.061019f-5 0.00016964384 -8.177574f-5 7.226926f-5 0.0002029227 9.5293704f-5 1.7909579f-5 8.9726265f-5 0.00015671499 -3.351027f-6 -0.00015060733 -4.8278973f-5 2.0857537f-5 -7.548793f-5; -0.00022084026 -9.0658395f-5 -0.00011320979 -0.00010201467 -7.0517846f-5 -0.000103105514 0.00015027537 -1.9837491f-5 2.399552f-6 -0.00017422538 0.00010130325 7.7387136f-5 3.5617802f-5 -0.00019887039 6.4667256f-5 -6.750992f-7 -3.337703f-5 -4.8389255f-5 6.023948f-5 0.00014555048 -3.2175765f-5 6.6452514f-5 0.00010843341 -0.00018370894 0.00020741967 -2.1191194f-5 6.073573f-5 -3.3258828f-5 0.00010722261 0.00013158047 -0.00010924604 -9.216183f-6; 5.4780747f-5 -0.0003333543 0.00011776382 4.3496322f-5 
-2.6848142f-5 1.290623f-5 -0.00012337601 -1.1477379f-5 0.00014269943 0.00013475996 9.481841f-5 1.6143511f-5 1.0612356f-5 -6.3082734f-5 -8.1219885f-5 6.778973f-6 0.00017800451 0.00014663895 0.0002840219 0.00020730388 -7.5503216f-5 1.7989278f-5 -0.00029786857 9.549613f-6 0.00012508479 4.1113886f-5 -0.00018869482 9.415384f-5 5.207556f-5 -1.9692068f-5 -1.2046044f-5 4.0358077f-6; 1.4737431f-5 -1.6053515f-5 6.301106f-6 -6.7161505f-5 6.156151f-5 -1.8843655f-5 -3.7794183f-5 -4.001278f-5 7.402184f-5 9.44986f-5 -6.540499f-5 9.916706f-5 7.3302053f-6 0.00014806012 -0.00010955876 -5.758521f-5 -1.3403496f-5 1.07390115f-5 0.00012690062 -6.453337f-5 8.509866f-5 7.2698545f-5 -1.9212115f-5 8.335749f-5 0.00013685325 -0.00018055117 8.770716f-5 9.3171955f-5 4.9919763f-5 -4.007098f-5 -5.073131f-5 1.9752182f-5; -0.000110297646 7.3692245f-5 5.7002613f-5 3.757033f-5 6.936504f-5 7.738579f-5 -0.00012802749 -2.4713056f-5 -3.3093816f-5 -7.5007716f-5 5.9600443f-5 0.00016254991 6.8976864f-5 -2.0663105f-5 -4.115667f-5 0.000120786826 -6.6412256f-5 9.285241f-5 7.668568f-5 -0.00012417938 -6.107705f-5 1.3646267f-5 -9.509996f-5 -2.586263f-5 1.7628112f-5 -0.00014053403 0.00016938079 -0.00014554722 -9.815545f-5 -1.3750721f-5 6.937025f-5 0.00013471323; -4.3220366f-6 7.636895f-5 5.5689932f-5 -5.831289f-5 5.8981463f-5 5.907253f-5 -2.1172094f-5 -0.00015647814 9.09139f-5 0.000121287034 -4.990922f-5 1.6338497f-6 8.270634f-5 0.00016870494 0.00020896753 9.420973f-5 -0.000119361604 0.00014137046 7.405659f-5 0.00010401659 5.296225f-5 -0.00010834467 4.1371688f-5 1.828621f-5 -2.0364396f-5 7.3975134f-5 -8.620203f-5 1.9471541f-5 4.2867465f-5 0.00014634254 -0.00014190121 5.97361f-5; 2.842149f-5 -0.0001390108 0.00024771033 4.150737f-5 -6.279219f-6 -0.00013444283 3.396861f-5 -0.00017824108 7.010992f-5 -8.14772f-5 8.830449f-5 -8.062201f-5 -2.5002722f-5 0.00010167737 8.430926f-6 -6.642483f-5 6.2908315f-5 7.658701f-6 -0.0001287037 0.00011137576 9.553649f-5 -3.1438514f-5 -0.00013510555 0.00012291994 0.00012074464 
-2.1073402f-5 -8.7760425f-5 2.7995618f-5 -1.8295434f-5 -9.720898f-5 2.8284183f-5 0.0001415438; -5.8446352f-5 -0.00015122785 0.0001593085 -4.641674f-5 6.807525f-5 0.00012329685 -4.8850063f-5 -0.00012529713 6.974015f-5 7.9347956f-5 -0.00012120158 3.2258909f-6 -0.00015234207 2.0792495f-5 5.8010373f-6 -0.0001827348 8.290032f-5 1.9521303f-5 5.85035f-7 -0.00014357359 5.9965594f-5 4.725024f-5 -4.8724283f-5 6.386635f-5 -1.8446852f-5 0.00016194872 4.377998f-6 -0.00011119193 0.00013444398 -2.88202f-6 0.00018299818 0.0001299258; 5.4772045f-5 2.6743592f-5 -3.7950605f-5 -6.456288f-5 2.2857123f-5 0.000117074356 -2.581451f-7 -2.326965f-6 -9.346618f-5 1.4302142f-5 5.834832f-5 7.489292f-5 -8.411193f-6 4.2638905f-5 0.0002049622 7.210017f-5 -0.00011986838 -7.240059f-5 -0.00014109987 3.662768f-5 1.1984916f-5 -2.1797452f-5 -0.00012682563 5.34511f-5 3.416906f-5 -7.218885f-5 7.393605f-5 1.0447531f-5 0.00017184565 -7.3962874f-6 -3.7607245f-7 2.0655369f-5; 0.00024575583 4.477467f-5 -3.322859f-5 -6.786506f-5 1.1935437f-5 -9.7573036f-5 -0.00010657273 -0.00011226045 -2.9953346f-6 -0.00025970914 9.5138326f-5 -9.4929805f-5 6.62886f-5 -1.885272f-5 8.4604995f-5 8.6286884f-5 -8.7978704f-5 -2.7180193f-5 -7.795209f-5 0.00012371145 1.4649836f-5 -4.0753937f-5 8.223443f-5 -0.0002055546 -1.2735223f-5 -7.979016f-5 -9.1610986f-5 0.00015620203 6.519782f-5 0.00013788766 -4.3885564f-5 -0.00020532678; -5.7761838f-5 0.000118674434 -4.022245f-6 2.3444218f-5 4.0354745f-8 0.00013102908 0.00014494953 3.734242f-5 4.9768067f-5 1.0978064f-5 -5.304449f-5 -2.5410567f-5 -0.00018835458 -0.00016039256 -9.446115f-5 0.0001017167 1.6507973f-5 -7.049449f-5 -4.167426f-6 8.723402f-5 0.00022946806 -2.8688223f-6 -8.483513f-5 4.748344f-5 5.818796f-5 7.439898f-5 1.3552518f-5 4.2757314f-5 4.830486f-7 4.0089722f-5 0.0001720303 -2.8506902f-5; 5.4004624f-5 -6.396038f-5 2.132994f-5 2.1886752f-5 -4.350948f-5 9.9748235f-5 7.256524f-5 0.00010073487 -0.00012668663 -0.00015849188 7.1473455f-6 -5.1716965f-5 0.00016950203 -4.9960803f-5 
9.7575044f-5 -0.000116846546 -1.3026327f-5 1.3299394f-5 4.996746f-5 5.7434223f-5 0.0003126827 -0.00016075815 -0.00010500814 0.00010206681 7.102943f-5 0.0001681586 -8.869655f-5 -6.793002f-5 -1.4030315f-5 4.4765555f-5 6.598368f-5 2.9958126f-6; 6.4925875f-6 -6.126039f-5 0.00013805811 -3.796978f-5 -1.9054537f-6 9.732556f-5 -6.473637f-5 -0.00015415557 -0.00011108422 0.00016246777 -7.696852f-6 1.1449648f-5 -0.00015430289 4.2633474f-5 3.9564187f-5 0.00015013127 1.3399546f-5 6.9335816f-5 -4.1667856f-5 5.962422f-6 1.0195803f-5 0.00024106311 -4.3936325f-5 -5.064092f-6 0.00018281917 0.00013052553 8.648647f-5 6.500058f-5 -0.00013796729 6.188306f-5 5.4759174f-5 0.00024666276; 7.428961f-5 7.805581f-5 -2.0048912f-5 -8.7992f-5 -8.2090424f-5 -3.894516f-5 -7.4187235f-5 -8.421134f-6 0.00011872028 -3.3907768f-6 5.583108f-5 5.5117518f-5 -5.750363f-5 -0.00011956834 6.628877f-5 3.7464662f-5 -8.01679f-5 -2.2944869f-5 -0.000105784944 -6.595276f-5 -4.297851f-5 -5.5946744f-5 -0.00014952742 -4.398651f-5 -3.8181537f-5 0.00018228234 0.00012093744 -7.2758885f-5 8.365641f-5 0.00014493059 -0.000103712024 -2.7459113f-5; -1.1372291f-5 -1.8804054f-5 -7.936723f-5 -2.340988f-5 -3.75854f-5 1.6863713f-5 -0.00014024075 3.3528486f-5 -5.9627797f-5 9.936325f-5 -0.00024060751 0.00021471917 -0.00010987669 8.045311f-5 5.9758335f-5 0.0001100237 -6.767041f-5 6.5000546f-5 -2.8979757f-5 -5.129556f-5 2.8350714f-5 -8.325591f-5 -6.2756386f-5 6.0081635f-5 -0.0001399508 3.728897f-5 5.897405f-5 -0.00016271598 6.146232f-5 -2.9611121f-5 -0.00010624628 -6.1504106f-5; 0.000118863165 4.1103925f-5 0.0002137884 -0.00021993836 4.418324f-5 0.0001527344 -5.363446f-6 4.08742f-5 1.3710059f-5 -7.280649f-5 -3.1318545f-5 -4.696824f-5 3.5970716f-5 -9.051088f-6 -5.2216375f-5 -8.007393f-6 4.3387427f-5 -1.26947225f-5 0.00018864355 9.461782f-5 -3.053834f-5 0.00014947265 -7.126701f-5 0.00017512559 -9.375449f-5 -0.00014118603 0.00017528646 0.00022708872 -2.9811661f-5 2.87172f-5 4.526286f-5 -0.000136196; 0.00023604267 -7.345335f-5 1.9743271f-5 
-0.00013814776 0.00013518608 3.9130664f-5 7.3728945f-5 8.378082f-5 1.3037745f-6 -6.6842826f-5 -0.00011881675 8.374014f-5 0.00013452528 -8.002644f-5 6.423297f-5 -8.742404f-5 -1.945782f-5 -0.00012333973 -1.4735208f-5 -4.0307412f-5 -5.90664f-6 3.714129f-5 -4.142277f-5 9.1647045f-5 2.3895907f-5 0.00016131696 0.000107153595 -8.047365f-5 -0.000103391816 0.00011903859 -4.755431f-5 -0.00010726879; 8.334862f-6 5.9595142f-5 4.7555448f-5 7.365213f-5 -3.2088086f-5 7.126067f-5 7.014639f-5 -0.00013384587 -8.863688f-5 -7.669142f-5 0.00019823949 3.551201f-5 0.00016557051 -0.0001284465 -0.00012736715 -6.684718f-5 -6.0223407f-5 -0.000116654184 -2.9300352f-5 7.6195174f-5 6.37634f-5 -7.894244f-5 2.6743777f-5 -7.59036f-5 4.9241884f-5 -8.201789f-5 3.6196663f-5 4.1520534f-6 9.076877f-5 -1.4967364f-5 4.200922f-6 0.00017322104], bias = Float32[-0.16792223, 0.17108434, -0.033846047, 0.072960645, 0.09885196, -0.16291727, -0.1316708, -0.008449579, 0.08691945, 0.17556772, -0.017550137, 0.038493097, 0.09326402, 0.026129022, 0.023156637, -0.056932405, -0.12245663, 0.13735487, -0.13388035, 0.029842371, 0.120385244, -0.11942054, 0.014552843, 0.14347413, -0.16169512, -0.07126387, -0.117433324, -0.08230025, -0.05788686, 0.14101708, -0.12888993, -0.02178326]), layer_4 = (weight = Float32[1.8479204f-5 -0.000121261524 -2.2483877f-5 9.626305f-5 -0.00012343402 -0.00015784666 -0.0001480165 2.5187019f-5 0.00010233112 0.00011360722 3.428437f-5 -1.3636172f-6 -0.00014541375 8.601685f-5 -7.000134f-5 8.7241504f-5 -8.910055f-5 5.5363824f-5 1.1091162f-5 6.28583f-5 5.8212703f-5 -2.9207102f-5 -0.0001204569 -2.6079564f-5 -6.8136706f-6 0.00015402751 -0.0001334096 -2.9714638f-6 -1.91305f-5 1.4484578f-5 1.4043293f-5 -0.00016508899; -4.5963066f-6 8.1387465f-5 0.00018817463 0.00023030551 -3.9679027f-5 -2.134191f-5 -5.584807f-5 0.000110629815 -4.639799f-5 1.843731f-5 -0.0002966492 5.0245108f-5 0.00014744225 6.004136f-6 7.824609f-5 -0.00024274473 3.869522f-6 2.3414763f-5 6.510159f-5 9.657588f-5 -8.334729f-5 -8.185208f-5 
-2.7925384f-5 -2.6427975f-5 -9.910939f-5 -7.3479123f-6 -5.565102f-6 5.600881f-6 2.9896642f-5 1.8901861f-5 -8.8975255f-5 -8.565923f-6], bias = Float32[-0.09112734, 0.0832014])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Similar to most DL frameworks, Lux defaults to using Float32, however, in this case we need Float64
# Lux creates Float32 parameters; convert them to Float64 to match the Float64 `u0` and
# `ode_model_params` used by the ODE.
const params = ComponentArray{Float64}(ps)
# Bundle the network with its state `st`. Parameters are not stored (`nothing`) but passed
# at call time, as in `nn_model([χ], nn_params)` inside `ODE_model`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
# Displayed value:
# StatefulLuxLayer{true}(
#     Chain(
#         layer_1 = WrappedFunction(Base.Fix1{typeof(broadcast), typeof(cos)}(broadcast, cos)),
#         layer_2 = Dense(1 => 32, cos),  # 64 parameters
#         layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
#         layer_4 = Dense(32 => 2),       # 66 parameters
#     ),
# )         # Total: 1_186 parameters,
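# plus 0 states.
Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, again using the state $u = (\chi, \phi)$. Each Newtonian rate is scaled by a learned correction:
$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$
where $p$, $M$, and $e$ are the same constant parameters as before, and $\mathcal{N}_1$, $\mathcal{N}_2$ are the two outputs of the neural network.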
"""
    ODE_model(u, nn_params, t)

Neural-ODE right-hand side for the state `u = [χ, ϕ]`.

Both Newtonian rates `(1 + e cos χ)² / (M p^(3/2))` are multiplied by `1 .+ nn_model([χ],
nn_params)`: the first network output corrects χ̇ and the second corrects ϕ̇. The constants
`(p, M, e)` come from the global `ode_model_params`. `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # Here the network state `st` is an empty NamedTuple, so it can be ignored: `nn_model`
    # already carries it and only the parameters are passed. A stateful network would need to
    # thread `st` through instead.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the untrained
# neural network parameters to simulate the model.
# Same setup as the ground truth, but the right-hand side is the neural ODE parameterized by
# `params`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# Integrate with the untrained parameters using the same fixed-step RK4 settings as `soln`.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Plus-polarization waveform of the untrained model, for comparison with `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # Ground truth (first pair) versus the untrained neural ODE (second pair). Each series is
    # drawn as a line plus markers that share one legend entry.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        alpha=0.5, marker=:circle, markersize=12, strokewidth=2)
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        alpha=0.5, marker=:circle, markersize=12, strokewidth=2)
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Integrate the neural ODE `prob_nn` with parameters `θ`, using the same fixed-step RK4
settings as the ground truth.

Returns `(l, pred_waveform)`:
- `l` is the `MSELoss` between the predicted plus-polarization waveform and the target
  `waveform`;
- `pred_waveform` is the prediction itself, passed on to the optimization callback.
"""
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    # Lux loss functions take `(ŷ, y)`: prediction first, target second. MSE is symmetric,
    # but keeping the convention matters if this loss is swapped for an asymmetric one.
    return mseloss(pred_waveform, waveform), pred_waveform
end
# Warm up the loss function
loss(params)(0.0023192591194921524, [-0.025679248922184945, -0.0248047776253177, -0.023930306328450605, -0.022459077708410283, -0.02037033653470165, -0.017634821381681032, -0.014215971709150909, -0.010069006416113887, -0.005145525599023101, 0.0006035714509276479, 0.007216114093371033, 0.014698182706778683, 0.022985855103596804, 0.03185993389341407, 0.04079261317714077, 0.04863478204275979, 0.05302920542646642, 0.049396650967285884, 0.029770387299834736, -0.015764131622083427, -0.08443068355747022, -0.1249595674993928, -0.059197913167016995, 0.06337908077376352, 0.11157545425893811, 0.0811510499139864, 0.032655785652469446, -0.004243292059390436, -0.02610317925707291, -0.03681097686693647, -0.04041229373584437, -0.03975038870957541, -0.03662700453059925, -0.0321458859837731, -0.02697768783902325, -0.02153150822912155, -0.01605692898272395, -0.010707499269291753, -0.005578179580374028, -0.0007287632222006621, 0.0038023258911917615, 0.00798871104994561, 0.01180744991812969, 0.015242663980616312, 0.018271749646606588, 0.020871248945668646, 0.02301134581525401, 0.024654292127673224, 0.025752093124608525, 0.026243759938627768, 0.02605207583765939, 0.025079500002745846, 0.023203564659115883, 0.020271089412853695, 0.01609260336890602, 0.010438727048615287, 0.0030410812643323903, -0.006384816386784802, -0.018075821167404898, -0.032035833857829564, -0.04757802963622244, -0.062237275859881355, -0.0694531951959364, -0.05509822748999199, 0.0005889886549748257, 0.09113466359786321, 0.1310609207310351, 0.04994523935635099, -0.05689709419031927, -0.09417325712434414, -0.07791618021755911, -0.0467095924654129, -0.01859892269276973, 0.0021555696613844355, 0.016104756006575495, 0.024786695673402245, 0.02964614413300726, 0.03177704836586437, 0.03197862881199061, 0.030809929924052668, 0.028668942087612483, 0.025838213436905108, 0.02251748940026476, 0.01885725203930474, 0.014965504326316098, 0.010921356057332953, 0.006796518012545093, 0.002635299548659503, -0.0015078832119500376, 
-0.0056000992585625645, -0.009590548869145314, -0.013443045647184146, -0.01709845058900661, -0.020502500721412978, -0.02357251389369566, -0.026215135760292087, -0.028298737040104806, -0.029654489348629416, -0.030053658195575605, -0.02918394810282521, -0.02662673409163263, -0.02180341451283335, -0.013956280547062223, -0.0021087123787050916, 0.014784975092062754, 0.03732814102604237, 0.06381236387683192, 0.08540083306960823, 0.07789167742949824, 0.007835978506848229, -0.10070720376407288, -0.1276642258065718, -0.04432763725513215, 0.04074624478613359, 0.07388494431355996, 0.07101855038390893, 0.054429196154219185, 0.0356212533904544, 0.018858309414088205, 0.005220351893287032, -0.00535485643136124, -0.013290735190700483, -0.019034877218328774, -0.02300337095080485, -0.025522144782713172, -0.026863444279740373, -0.027231286450501817, -0.0267915822513703, -0.02567068688014666, -0.023967920282408392, -0.021764217371696065, -0.01911830560524247, -0.016085581308614663, -0.012700868229235454, -0.00900660750913898, -0.005027221300044713, -0.0007951348562083519, 0.003648369033771454, 0.00827812075710593, 0.013034030534173369, 0.017849557630078564, 0.022617652926210693, 0.027199048390900937, 0.03136233896546854, 0.03476647455941198, 0.03688516689430167, 0.0368980026666889, 0.03353638153850386, 0.02487037634030461, 0.008171688683095719, -0.019675771933563357, -0.05918798249164297, -0.09826989778613192, -0.09576665538396008, -0.006297375116834064, 0.10714718981480661, 0.11761141712281381, 0.046273888752087204, -0.018983901416324094, -0.050776161822454695, -0.05834908955899362, -0.053728811821941655, -0.04419499438471625, -0.03335708217012476, -0.022841236117548933, -0.013318657167549448, -0.004990327742291687, 0.002124352273938042, 0.008085180250832184, 0.013003175710945003, 0.016966575684236457, 0.02006941475587743, 0.022384289612491382, 0.023985428971574317, 0.02491918063518844, 0.025228058416986682, 0.024943300034661876, 0.024085658264213716, 0.022667378620517604, 
0.020691526401284285, 0.018162187775865663, 0.015061327817179943, 0.011381217575685762, 0.00710100086123553, 0.0022134949169542455, -0.0033078215390008004, -0.009450418618446806, -0.016183583298432516, -0.023402434628464933, -0.030901202510339707, -0.03823144605168246, -0.0445374814233703, -0.048199742518590924, -0.046232802739660406, -0.033476006848840445, -0.0024396368558002248, 0.05169242252610585, 0.11015818912400929, 0.10551219611426906, -0.002438306024174045, -0.10423583080112273, -0.10586039327546559, -0.05334797738997444, -0.004318424399212751, 0.025665779910217148, 0.039930913409155284, 0.04429701186905635, 0.043041583274987284, 0.038796795782148855, 0.03309034706335514, 0.026793042141062536, 0.02039459196582214, 0.014179011472500846, 0.008285396708893895, 0.0028024825871911514, -0.0022328128835325525, -0.006792018331531639, -0.010881020732635693, -0.014486681959512708, -0.017609343383842124, -0.020237355400572855, -0.022369980882450197, -0.023984627293130716, -0.025059406922978898, -0.025562120547741762, -0.025450210933502204, -0.0246675191791299, -0.02314223619341742, -0.020793357797375924, -0.01750388984670667, -0.013150242179264986, -0.007577267928824017, -0.0006272252415455213, 0.007879041338981868, 0.018044839411298714, 0.029813818405713625, 0.042675251129656484, 0.0550745169330953, 0.06309022579360751, 0.05818420787289884, 0.02554488990896564, -0.04606835667360864, -0.12182471098882187, -0.10479333399776065, 0.009475997621701884, 0.09132892235672099, 0.09407386423553768, 0.060646937767724104, 0.025498144981459397, -0.00965064780480534])Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Optimization callback:
- appends the current loss `l` to the global `losses`;
- prints a progress line with the iteration count (the number of recorded losses).

Always returns `false`, so it never asks the optimizer to stop early.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf("Training %10s Iteration: %5d %10s Loss: %.10f\n", "", iteration, "", l)
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This gives good results, presumably because the Newtonian model already provides a very good initial guess.
adtype = Optimization.AutoZygote()
# Only the parameter vector matters to the objective; the hyper-parameter slot is unused.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations;
# `callback` records and prints the loss after each step.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback=callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.000134866641019471; -6.086827124818228e-5; -2.18156237679068e-5; -0.0001018360853777267; 9.518320439383388e-6; 9.755264181876555e-5; -8.027584408409894e-5; 0.00016536054317839444; -0.00012283978867344558; -3.1131650757743046e-5; -1.5506975614698604e-5; 0.00014282212941907346; -4.820002868655138e-5; -9.392016181664076e-6; -2.2532549337483943e-5; 5.5018765124259517e-5; -5.494487413670868e-5; 2.955016680061817e-5; -0.00018670462304726243; -9.627539111534134e-5; -0.00012719715596176684; 0.00016724379383958876; 0.00010966541594825685; -0.0001452295546187088; 7.603982521686703e-5; -6.387787288986146e-5; -0.00011855418415507302; 8.658943988848478e-5; 0.00015967870422173291; -9.286082786275074e-5; 4.985584018868394e-5; 9.98468603938818e-5;;], bias = [0.596187949180603, -0.6482552289962769, -0.23217427730560303, -0.45540428161621094, -0.8901481628417969, -0.1967073678970337, 0.7987875938415527, 0.4284961223602295, -0.40056610107421875, 0.15956735610961914, -0.01093912124633789, -0.03423440456390381, -0.13895881175994873, 0.7470607757568359, -0.33783185482025146, -0.9743046760559082, -0.12773585319519043, -0.5957316160202026, -0.918546199798584, 0.3225581645965576, 0.38599932193756104, 0.46897387504577637, -0.08547067642211914, -0.6247744560241699, 0.3157527446746826, 0.8093822002410889, 0.5723259449005127, 0.9110771417617798, 0.3229248523712158, -0.5108339786529541, 0.34849047660827637, -0.21961939334869385]), layer_3 = (weight = [-8.926591544877738e-5 0.0001271493820240721 1.355249651169288e-6 -9.066618076758459e-5 0.00010753438255051151 -4.883634392172098e-5 -5.0133774493588135e-5 -0.00010982101957779378 -1.3924720406066626e-5 -0.00014842844393569976 -4.40010117017664e-5 0.00010470116831129417 -6.111922994023189e-5 -2.3183349185273983e-5 3.640575232566334e-5 -3.8583668356295675e-5 4.759052171721123e-5 -9.343397687189281e-5 0.0001257338881259784 -0.0002273207064718008 4.740366784972139e-5 
2.4019940610742196e-5 -5.7272514823125675e-5 -5.655725544784218e-6 3.453420868027024e-5 -7.92658029240556e-5 2.5768558771233074e-5 -6.59190254737041e-6 6.562397175002843e-5 3.686684794956818e-5 4.567530140775489e-6 6.357259553624317e-5; -0.0002670119283720851 -0.00014138963888399303 -0.00013708665210288018 -0.00011134130909340456 2.839010994648561e-5 -1.456270274502458e-5 -0.00012981360487174243 -2.7648442483041435e-5 -8.626655471744016e-5 0.0001375194697175175 -1.972119753190782e-5 5.361830517358612e-6 -0.00025962956715375185 8.586419426137581e-5 0.00010760478471638635 2.637031002450385e-7 5.890399552299641e-5 -3.439670035731979e-5 0.0001094111503334716 4.674156116379891e-6 2.2627930320595624e-6 -0.0001947374694282189 -3.777692108997144e-5 -0.0001101541129173711 -3.6522807931760326e-5 -0.00025190654559992254 2.469361606927123e-5 -3.3183740015374497e-5 0.0001937649940373376 -0.00012616688036359847 4.161094329901971e-5 -7.07853541825898e-5; -7.617571827722713e-5 -1.5221440662571695e-5 -3.6111076042288914e-5 0.00012157452874816954 -6.654173193965107e-5 1.1459002053015865e-5 4.6367622417164966e-5 -3.453728595559369e-6 -0.00010836191358976066 6.017402483848855e-5 1.9261093257227913e-5 -0.00014976202510297298 5.1565504691097885e-5 9.130991202255245e-6 0.0001544108527014032 1.8040637996818987e-6 2.3114860596251674e-5 -4.9094163841800764e-5 -5.13850427523721e-5 -2.741505522863008e-5 -5.124018935021013e-5 -9.126013173954561e-5 -0.00022966247342992574 -0.0001560166710987687 1.157075985247502e-6 -8.713528222870082e-5 9.48149390751496e-5 -0.00014354732411447912 -0.00012639361375477165 -0.00018863505101762712 -1.4192917205946287e-6 8.051460463320836e-5; -6.425934407161549e-5 -3.697932334034704e-5 0.00018758275837171823 -7.3173064265574794e-6 1.3022076927882154e-5 -1.8890104911406524e-5 -3.144254515063949e-5 -7.21268734196201e-5 -7.352473039645702e-5 -9.973171836463735e-5 -9.680505172582343e-5 -2.0598106857505627e-5 -0.0001033810040098615 -4.468032420845702e-5 
0.00011491870100144297 3.6679597542388365e-5 0.0001408526732120663 2.9614411687362008e-5 -5.5192005675053224e-5 -8.906838047550991e-5 -0.00010248077160213143 -3.552027192199603e-5 -5.0916914915433154e-5 -6.715955532854423e-5 -8.103359141387045e-5 0.00014674455451313406 -9.47652806644328e-5 0.00021752362954430282 0.00018389341130387038 -7.410090620396659e-5 0.0001225564192282036 7.947810081532225e-5; -6.738596857758239e-5 -8.93115284270607e-5 -0.0001304640172747895 5.0647093303268775e-5 9.59940007305704e-5 3.053608816117048e-5 -2.804544237733353e-5 6.215293251443654e-5 -2.6406580218463205e-5 -5.184479596209712e-5 2.442616460029967e-5 0.00017085943545680493 0.0001904481032397598 9.137973393080756e-5 -5.900539690628648e-5 0.00018449430353939533 4.0156963223125786e-5 2.5515941160847433e-5 -2.2844731574878097e-5 2.5007733711390756e-5 -1.8956679923576303e-5 6.031763041391969e-5 5.338010305422358e-5 0.00012101601896574721 8.964601875049993e-5 0.0002177903224946931 -4.246356547810137e-5 -0.00017798686167225242 5.731026249122806e-5 9.024583414429799e-5 5.640337258228101e-5 0.00012905432959087193; -1.0263779586239252e-5 -7.673001528019086e-5 -3.94672024413012e-5 0.00019675811927299947 0.0001596912188688293 3.043556171178352e-5 -0.00021834402286913246 -5.553643131861463e-5 7.705603820795659e-6 -4.745323894894682e-5 -0.00017646217020228505 -0.00013773568207398057 2.352297815377824e-5 7.131765596568584e-5 4.235611777403392e-5 8.16824467619881e-5 6.827913603046909e-5 -0.00011673913104459643 1.3830483112542424e-5 -4.991735841031186e-5 -0.0001532085589133203 -5.658990266965702e-5 -6.206926627783105e-5 4.881788117927499e-5 1.0390524494141573e-6 5.6706153372942936e-6 -4.488635022426024e-5 -5.891834007343277e-5 7.315494440263137e-5 0.00020035536726936698 -3.938681402360089e-5 -7.201840344350785e-5; 0.00010926267714239657 -1.7412732631783e-5 -2.798033347062301e-5 -9.992849663831294e-5 0.0002475100918672979 7.910590647952631e-5 5.469330426421948e-5 6.685949216489462e-8 
0.00010413610289106146 2.5406202439626213e-6 -6.236748959054239e-6 -1.5909134162939154e-5 2.3462140234187245e-5 1.2387028618832119e-5 -0.0002675730502232909 -1.3177066648495384e-5 9.088885417440906e-5 -0.0001527809799881652 7.030127744656056e-5 -2.4672192012076266e-5 9.935075649991632e-5 4.8755980969872326e-5 -8.359791536349803e-5 -0.0002670779940672219 -0.00010592828039079905 -0.00015420980344060808 1.7652557289693505e-5 -3.1884243071544915e-5 0.00011062163684982806 -5.884900747332722e-5 -0.00012774026254191995 2.2638607333647087e-5; -1.8954291590489447e-5 3.255678529967554e-5 3.587321407394484e-5 -5.7337663747603074e-5 0.0001066511613316834 -0.0001280024298466742 -0.0001570741878822446 -6.199925701366737e-5 5.524623338715173e-5 2.2884545614942908e-5 0.00010430310794617981 -6.045283043931704e-6 -6.281781679717824e-5 -7.57820307626389e-5 -0.00010418124293209985 -0.0001386780058965087 -2.3946366127347574e-5 -6.701449456159025e-5 5.841551683261059e-5 -2.610613410070073e-5 1.7852362361736596e-5 -2.2024663849151693e-5 7.992082100827247e-5 -0.00010937178012682125 2.7628420866676606e-5 -5.299588792695431e-6 -0.00010500459029572085 -0.00010136794298887253 -4.995677954866551e-5 -5.522587161976844e-5 -1.1050815373891965e-5 -7.702720904489979e-5; -3.869339343509637e-5 6.048114300938323e-5 -0.00010209737229160964 0.00018429513147566468 -2.2352409359882586e-5 3.4792537917383015e-5 -6.912487879162654e-5 4.70114428026136e-5 -1.1644242476904765e-5 3.805059168371372e-5 -2.365293948969338e-5 -7.39039751351811e-5 -7.32258013158571e-6 -4.0900788008002564e-5 0.00016015190340112895 1.0778848263726104e-5 -0.00013308408961165696 2.8580494472407736e-5 0.00013405177742242813 7.49281098251231e-5 -0.00016382151807192713 -4.271945363143459e-5 4.3694260966731235e-5 7.496435136999935e-5 3.119203029200435e-5 0.0001585590507602319 1.2739033991238102e-5 7.506147085223347e-5 6.0280995967332274e-5 5.846530257258564e-5 -1.929401514644269e-5 9.19915983104147e-5; -0.00010772809764603153 
-6.410031346604228e-5 -0.0001563838159199804 -0.00017781068163458258 -0.0001420443586539477 0.00019516782776918262 -9.975087596103549e-5 -1.9362285456736572e-5 4.27780942118261e-5 0.00012933585094287992 6.321484397631139e-5 -9.38036828301847e-5 3.361182098160498e-5 0.00019938050536438823 1.9648305169539526e-5 -5.45790717296768e-5 -4.71722123620566e-5 -9.403144213138148e-5 -0.00013725088501814753 1.5356670701294206e-5 -7.965367694851011e-5 9.126882105192635e-6 0.0001193690623040311 -9.018309356179088e-6 -4.166498911217786e-6 2.1959687728667632e-5 -0.00020483673142734915 2.9876988264732063e-5 1.152173263108125e-5 4.047941911267117e-5 -0.00011473489576019347 1.9828014046652243e-5; -7.337908027693629e-5 4.856976738665253e-5 -0.0001001653290586546 5.3817788284504786e-5 7.800709863658994e-5 2.046433110081125e-5 5.693881030310877e-5 0.00013255688827484846 6.69247456244193e-5 0.00019174152112100273 6.529416509692965e-7 -0.00018148121307604015 -1.3895557458454277e-5 -9.108118683798239e-5 1.329107180936262e-5 -8.925779548007995e-5 -3.191329960827716e-5 0.00019684752624016255 0.00013937352923676372 7.76049455453176e-6 9.537396545056254e-5 -8.862952381605282e-5 -2.5396184355486184e-5 -0.00017467299767304212 -0.00010525292600505054 5.822121966048144e-5 -2.7217613023822196e-5 -0.00012075265112798661 -1.0472348549228627e-5 6.42828454147093e-5 -0.0001419224718119949 3.134948929073289e-5; 1.3229305295681115e-5 6.563753413502127e-5 0.00024242204381152987 -0.0001874320296337828 -2.5201255994034e-5 -2.059603139059618e-5 4.569145403365837e-6 -0.0001547181891510263 9.39327510423027e-5 0.00013544998364523053 -7.967778947204351e-5 -3.0933922971598804e-5 2.7198930183658376e-5 6.892983492434723e-6 0.00010308256605640054 5.896816333006427e-7 6.236006447579712e-5 -9.66033258009702e-5 -0.00020615995163097978 -8.77826678333804e-5 -0.0002484667347744107 2.4875786039046943e-5 -1.5486939446418546e-5 0.00014435446064453572 7.49900791561231e-5 -2.4984192350530066e-5 -6.124633364379406e-5 
-0.00014545007434207946 -6.793274951633066e-5 0.0001824709033826366 -0.00011101221753051504 -2.745397614489775e-5; -2.274061625939794e-5 -7.696486136410385e-5 1.3907084394304547e-5 7.924704550532624e-5 2.481909359630663e-5 9.735042112879455e-5 1.8477943740435876e-5 -0.00026380925555713475 -8.678693120600656e-5 5.958184556220658e-5 -3.392634971532971e-5 9.227154805557802e-5 2.8709695470752195e-5 -0.00012130848335800692 8.174160757334903e-5 0.00014467505388893187 -0.00015725285629741848 2.6391235223854892e-5 0.00016082506044767797 -0.00015254410391207784 -6.8175786509527825e-6 -6.782614946132526e-5 -2.944449079222977e-5 -6.723080150550231e-5 4.4934316974831745e-5 -9.869581845123321e-5 -1.1900482604687568e-5 9.956004214473069e-5 0.0001762897736625746 0.0001235189411090687 1.0605775059957523e-5 -0.00014333122817333788; -1.9166856873198412e-5 -1.5859868653933518e-5 -0.00014222170284483582 0.00018814214854501188 -0.0001462130603613332 -9.13706753635779e-5 -2.857298113667639e-6 6.123725324869156e-5 -4.952047311235219e-5 0.00010888888209592551 1.1925430953851901e-5 -0.0001404148933943361 -1.522609636595007e-5 0.00010423452476970851 9.008094639284536e-5 1.5647377949790098e-5 0.0001605345169082284 5.059656541561708e-5 -4.722956145997159e-5 8.744777733227238e-5 0.00023447915737051517 6.56900301692076e-5 -8.444372360827401e-5 8.228691513068043e-6 0.00010451538400957361 -4.815788088308182e-6 -1.6078394764917903e-5 0.00013036686868872494 0.00010218031820841134 -0.00020275666611269116 -6.469112850027159e-5 0.0001208842295454815; 0.00011470341996755451 9.735720232129097e-5 0.0002463065611664206 0.00011013289622496814 0.00011974869994446635 -5.5814130064391065e-6 1.2791753761121072e-5 -3.488701258902438e-5 -4.499801070778631e-5 -6.321255932562053e-5 -6.191413558553904e-5 0.00015609440742991865 -0.00013149986625649035 -9.848023910308257e-5 -5.6994304031832144e-5 4.8205580242211e-5 -4.026605893159285e-5 -0.00024225845118053257 -5.061018964624964e-5 0.0001696438412182033 
-8.177573909051716e-5 7.226925663417205e-5 0.00020292270346544683 9.529370436212048e-5 1.790957867342513e-5 8.972626528702676e-5 0.00015671498840674758 -3.351027089593117e-6 -0.00015060733130667359 -4.827897282666527e-5 2.0857536583207548e-5 -7.548792927991599e-5; -0.0002208402584074065 -9.065839549293742e-5 -0.0001132097895606421 -0.00010201466648140922 -7.051784632494673e-5 -0.00010310551442671567 0.00015027537301648408 -1.9837490981444716e-5 2.399551931375754e-6 -0.00017422538076061755 0.00010130325244972482 7.738713611615822e-5 3.561780249583535e-5 -0.00019887038797605783 6.466725608333945e-5 -6.750992156412394e-7 -3.3377029467374086e-5 -4.8389254516223446e-5 6.023947935318574e-5 0.0001455504825571552 -3.2175765227293596e-5 6.645251414738595e-5 0.00010843340714927763 -0.00018370893667452037 0.0002074196672765538 -2.1191193809499964e-5 6.073572876630351e-5 -3.325882789795287e-5 0.00010722260776674375 0.00013158046931494027 -0.00010924603702733293 -9.216182661475614e-6; 5.4780746722826734e-5 -0.00033335431362502277 0.00011776381870731711 4.349632217781618e-5 -2.6848141715163365e-5 1.2906230040243827e-5 -0.00012337601219769567 -1.147737930295989e-5 0.00014269942766986787 0.00013475996092893183 9.481840970693156e-5 1.6143510947586037e-5 1.0612356163619552e-5 -6.308273441391066e-5 -8.1219885032624e-5 6.778972874599276e-6 0.00017800451314542443 0.0001466389512643218 0.00028402189491316676 0.0002073038776870817 -7.550321606686339e-5 1.7989277694141492e-5 -0.00029786856612190604 9.549613423587289e-6 0.00012508478539530188 4.111388625460677e-5 -0.00018869481573347002 9.415383829036728e-5 5.2075560233788565e-5 -1.9692068235599436e-5 -1.2046043593727518e-5 4.035807705804473e-6; 1.4737431229150388e-5 -1.6053514627856202e-5 6.301106168393744e-6 -6.716150528518483e-5 6.156151357572526e-5 -1.8843655197997577e-5 -3.779418329941109e-5 -4.001277920906432e-5 7.402183837257326e-5 9.449860226595774e-5 -6.540498725371435e-5 9.916706039803103e-5 7.330205335165374e-6 
0.00014806012040935457 -0.00010955875768559054 -5.758521001553163e-5 -1.3403496268438175e-5 1.0739011486293748e-5 0.00012690061703324318 -6.453337118728086e-5 8.509866165695712e-5 7.269854540936649e-5 -1.9212115148548037e-5 8.335748862009495e-5 0.0001368532539345324 -0.00018055117106996477 8.77071579452604e-5 9.317195508629084e-5 4.991976311430335e-5 -4.007097959402017e-5 -5.073130887467414e-5 1.975218219740782e-5; -0.00011029764573322609 7.369224476860836e-5 5.700261317542754e-5 3.757033118745312e-5 6.936504360055551e-5 7.73857900639996e-5 -0.00012802748824469745 -2.4713055609026924e-5 -3.3093816455220804e-5 -7.500771607737988e-5 5.9600442909868434e-5 0.00016254991351161152 6.897686398588121e-5 -2.066310480586253e-5 -4.115666888537817e-5 0.00012078682630090043 -6.641225627390668e-5 9.285240957979113e-5 7.668568287044764e-5 -0.00012417937978170812 -6.107705121394247e-5 1.3646266779687721e-5 -9.509996016277e-5 -2.5862629627226852e-5 1.7628111891099252e-5 -0.00014053402992431074 0.00016938078624662012 -0.00014554722292814404 -9.815544763114303e-5 -1.375072133669164e-5 6.937025318620726e-5 0.00013471323472913355; -4.322036602388835e-6 7.636895315954462e-5 5.568993219640106e-5 -5.8312889450462535e-5 5.898146264371462e-5 5.907252852921374e-5 -2.1172094420762733e-5 -0.00015647814143449068 9.091389802051708e-5 0.00012128703383496031 -4.99092202517204e-5 1.6338497061951784e-6 8.270634134532884e-5 0.00016870493709575385 0.00020896752539556473 9.420973219675943e-5 -0.00011936160444747657 0.00014137045945972204 7.40565883461386e-5 0.00010401658801129088 5.2962248446419835e-5 -0.00010834466957021505 4.1371687984792516e-5 1.8286209524376318e-5 -2.036439582298044e-5 7.397513400064781e-5 -8.620203152531758e-5 1.947154123627115e-5 4.2867464799201116e-5 0.0001463425433030352 -0.00014190121146384627 5.973609950160608e-5; 2.842148933268618e-5 -0.00013901079364586622 0.0002477103262208402 4.150737004238181e-5 -6.279219178395579e-6 -0.00013444283104036003 3.3968608477152884e-5 
-0.00017824108363129199 7.010991976130754e-5 -8.14771992736496e-5 8.830449223751202e-5 -8.062201231950894e-5 -2.5002722395583987e-5 0.0001016773676383309 8.43092584545957e-6 -6.642482912866399e-5 6.290831515798345e-5 7.658701179025229e-6 -0.00012870370119344443 0.00011137576075270772 9.553648851579055e-5 -3.143851427012123e-5 -0.00013510555436369032 0.00012291994062252343 0.0001207446402986534 -2.1073401512694545e-5 -8.776042523095384e-5 2.7995618438581005e-5 -1.82954336196417e-5 -9.720897651277483e-5 2.828418291755952e-5 0.00014154380187392235; -5.84463523409795e-5 -0.0001512278540758416 0.00015930850349832326 -4.641673876903951e-5 6.807524914620444e-5 0.00012329684977885336 -4.8850062739802524e-5 -0.0001252971269423142 6.974014831939712e-5 7.934795576147735e-5 -0.00012120157771278173 3.225890850444557e-6 -0.0001523420651210472 2.079249497910496e-5 5.801037332275882e-6 -0.00018273480236530304 8.290031837532297e-5 1.9521303329383954e-5 5.850350248692848e-7 -0.00014357359032146633 5.9965594118693843e-5 4.72502397315111e-5 -4.8724283260526136e-5 6.38663477730006e-5 -1.8446851754561067e-5 0.00016194871568586677 4.377997811388923e-6 -0.00011119192640762776 0.00013444398064166307 -2.8820199986512307e-6 0.00018299817747902125 0.00012992580013815314; 5.477204467752017e-5 2.674359166121576e-5 -3.795060547417961e-5 -6.456288247136399e-5 2.285712253069505e-5 0.0001170743562397547 -2.5814509285737586e-7 -2.3269649318535812e-6 -9.346618026029319e-5 1.4302141607913654e-5 5.83483197260648e-5 7.489292329410091e-5 -8.411193448409904e-6 4.263890514266677e-5 0.00020496219804044813 7.210017065517604e-5 -0.00011986838217126206 -7.240058766910806e-5 -0.00014109986659605056 3.662767994683236e-5 1.1984915545326658e-5 -2.1797452063765377e-5 -0.0001268256310140714 5.345109821064398e-5 3.416906110942364e-5 -7.218885002657771e-5 7.393604755634442e-5 1.044753116730135e-5 0.0001718456478556618 -7.396287401206791e-6 -3.7607244962600817e-7 2.065536864392925e-5; 0.0002457558293826878 
4.477467155084014e-5 -3.322858901810832e-5 -6.786506128264591e-5 1.1935437214560807e-5 -9.757303632795811e-5 -0.0001065727265086025 -0.0001122604517149739 -2.995334625666146e-6 -0.00025970913702622056 9.513832628726959e-5 -9.492980461800471e-5 6.628860137425363e-5 -1.8852719222195446e-5 8.460499520879239e-5 8.628688374301419e-5 -8.797870395937935e-5 -2.718019277381245e-5 -7.795209239702672e-5 0.00012371144839562476 1.4649835975433234e-5 -4.07539373554755e-5 8.223443001043051e-5 -0.00020555460650939494 -1.2735223208437674e-5 -7.979015936143696e-5 -9.161098569165915e-5 0.00015620203339494765 6.519781891256571e-5 0.00013788766227662563 -4.388556408230215e-5 -0.00020532678172457963; -5.776183752459474e-5 0.00011867443390656263 -4.022244866064284e-6 2.3444217731594108e-5 4.035474532315675e-8 0.00013102908269502223 0.00014494953211396933 3.734241909114644e-5 4.976806667400524e-5 1.0978063983202446e-5 -5.304448859533295e-5 -2.541056710470002e-5 -0.00018835457740351558 -0.00016039256297517568 -9.446115291211754e-5 0.00010171670146519318 1.6507972759427503e-5 -7.04944905010052e-5 -4.167426141066244e-6 8.723401697352529e-5 0.00022946805984247476 -2.8688223210338037e-6 -8.48351264721714e-5 4.7483441449003294e-5 5.8187961258227006e-5 7.439898035954684e-5 1.3552517884818371e-5 4.275731407687999e-5 4.830486091123021e-7 4.008972246083431e-5 0.00017203029710799456 -2.8506901799119078e-5; 5.4004623962100595e-5 -6.39603822492063e-5 2.132994086423423e-5 2.1886751710553654e-5 -4.350948074716143e-5 9.974823478842154e-5 7.256524258991703e-5 0.00010073486919282004 -0.00012668663111981004 -0.00015849187911953777 7.147345513658365e-6 -5.1716964662773535e-5 0.00016950203280430287 -4.996080315322615e-5 9.757504449225962e-5 -0.00011684654600685462 -1.3026326996623538e-5 1.3299393685883842e-5 4.996746065444313e-5 5.74342229811009e-5 0.00031268270686268806 -0.00016075815074145794 -0.00010500814096303657 0.00010206681326963007 7.10294334567152e-5 0.00016815859999042004 -8.869654993759468e-5 
-6.793002103222534e-5 -1.4030314559931867e-5 4.476555477594957e-5 6.598368054255843e-5 2.995812565131928e-6; 6.492587544926209e-6 -6.126039079390466e-5 0.0001380581088596955 -3.796978126047179e-5 -1.905453700601356e-6 9.732555918162689e-5 -6.473637040471658e-5 -0.00015415556845255196 -0.000111084220407065 0.0001624677679501474 -7.696851753280498e-6 1.144964790000813e-5 -0.00015430289204232395 4.2633473640307784e-5 3.9564187318319455e-5 0.00015013126539997756 1.3399546332948375e-5 6.933581607881933e-5 -4.166785583947785e-5 5.962422164884629e-6 1.0195803042734042e-5 0.00024106311320792884 -4.3936324800597504e-5 -5.064091965323314e-6 0.0001828191743697971 0.00013052552822045982 8.648647053632885e-5 6.500058225356042e-5 -0.00013796729035675526 6.188305997056887e-5 5.475917350850068e-5 0.00024666276294738054; 7.428960816469043e-5 7.80558111728169e-5 -2.004891211981885e-5 -8.799199713394046e-5 -8.20904242573306e-5 -3.894515975844115e-5 -7.418723544105887e-5 -8.421134225500282e-6 0.0001187202797154896 -3.3907767829077784e-6 5.5831078498158604e-5 5.511751805897802e-5 -5.75036283407826e-5 -0.00011956834350712597 6.628876872127876e-5 3.7464662455022335e-5 -8.016789797693491e-5 -2.29448687605327e-5 -0.00010578494402579963 -6.595275772269815e-5 -4.2978510464308783e-5 -5.594674439635128e-5 -0.00014952741912566125 -4.3986510718241334e-5 -3.818153709289618e-5 0.00018228233966510743 0.00012093743862351403 -7.275888492586091e-5 8.365640678675845e-5 0.000144930585520342 -0.00010371202370151877 -2.7459112970973365e-5; -1.1372290828148834e-5 -1.880405397969298e-5 -7.936722977319732e-5 -2.3409880668623373e-5 -3.7585399695672095e-5 1.6863712517078966e-5 -0.00014024075062479824 3.352848580107093e-5 -5.962779687251896e-5 9.936324931913987e-5 -0.0002406075072940439 0.00021471916988957673 -0.0001098766879294999 8.045310823945329e-5 5.975833482807502e-5 0.00011002369865309447 -6.767040758859366e-5 6.500054587377235e-5 -2.897975718951784e-5 -5.1295559387654066e-5 2.8350714273983613e-5 
-8.325590897584334e-5 -6.275638588704169e-5 6.008163472870365e-5 -0.00013995080371387303 3.728896990651265e-5 5.897404844290577e-5 -0.00016271597996819764 6.14623204455711e-5 -2.9611121135530993e-5 -0.00010624628339428455 -6.150410627014935e-5; 0.00011886316497111693 4.110392546863295e-5 0.00021378840028773993 -0.0002199383598053828 4.418324169819243e-5 0.0001527343993075192 -5.363445779948961e-6 4.087420165888034e-5 1.3710058738070074e-5 -7.280649151653051e-5 -3.131854464299977e-5 -4.696824180427939e-5 3.59707155439537e-5 -9.051088454725686e-6 -5.2216375479474664e-5 -8.007393262232654e-6 4.3387426558183506e-5 -1.2694722499873023e-5 0.00018864354933612049 9.461781883146614e-5 -3.053834007005207e-5 0.00014947264571674168 -7.126700802473351e-5 0.00017512559134047478 -9.37544900807552e-5 -0.00014118602848611772 0.00017528646276332438 0.00022708871983923018 -2.9811661079293117e-5 2.8717200621031225e-5 4.526285920292139e-5 -0.000136196002131328; 0.00023604267335031182 -7.345335325226188e-5 1.9743270968319848e-5 -0.00013814776320941746 0.0001351860846625641 3.9130663935793564e-5 7.37289446988143e-5 8.378081838600338e-5 1.3037745247856947e-6 -6.684282561764121e-5 -0.00011881675163749605 8.374013850698248e-5 0.00013452528219204396 -8.002643880899996e-5 6.423296872526407e-5 -8.742404315853491e-5 -1.9457820599200204e-5 -0.00012333973427303135 -1.4735208424099255e-5 -4.030741183669306e-5 -5.90664012634079e-6 3.714128979481757e-5 -4.142276884522289e-5 9.164704533759505e-5 2.389590736129321e-5 0.00016131695883814245 0.00010715359530877322 -8.047364826779813e-5 -0.00010339181608287618 0.00011903858830919489 -4.755430927616544e-5 -0.00010726878826972097; 8.334862286574207e-6 5.95951423747465e-5 4.755544796353206e-5 7.365213241428137e-5 -3.208808630006388e-5 7.126067066565156e-5 7.014638686086982e-5 -0.00013384586782194674 -8.863687980920076e-5 -7.669142360100523e-5 0.00019823948969133198 3.551201007212512e-5 0.00016557051276322454 -0.00012844649609178305 -0.00012736715143546462 
-6.684717664029449e-5 -6.022340676281601e-5 -0.00011665418423945084 -2.9300352252903394e-5 7.619517418788746e-5 6.376340024871752e-5 -7.894243753980845e-5 2.674377719813492e-5 -7.590359746245667e-5 4.924188397126272e-5 -8.20178902358748e-5 3.6196663131704554e-5 4.152053406869527e-6 9.076877176994458e-5 -1.496736422268441e-5 4.200921921437839e-6 0.00017322103667538613], bias = [-0.1679222285747528, 0.17108434438705444, -0.03384604677557945, 0.07296064496040344, 0.09885195642709732, -0.16291727125644684, -0.13167080283164978, -0.008449578657746315, 0.08691944926977158, 0.17556771636009216, -0.017550136893987656, 0.03849309682846069, 0.09326402097940445, 0.026129022240638733, 0.02315663732588291, -0.05693240463733673, -0.1224566325545311, 0.13735486567020416, -0.1338803470134735, 0.029842371121048927, 0.12038524448871613, -0.11942054331302643, 0.014552842825651169, 0.14347413182258606, -0.16169512271881104, -0.0712638720870018, -0.1174333244562149, -0.08230025321245193, -0.05788686126470566, 0.14101707935333252, -0.12888993322849274, -0.021783260628581047]), layer_4 = (weight = [1.8479204300092533e-5 -0.00012126152432756498 -2.2483876819023862e-5 9.626305109122768e-5 -0.00012343401613179594 -0.0001578466617502272 -0.00014801649376749992 2.5187018763972446e-5 0.0001023311197059229 0.00011360721691744402 3.4284370485693216e-5 -1.3636172297992744e-6 -0.00014541375276166946 8.6016851128079e-5 -7.000134064583108e-5 8.724150393391028e-5 -8.910054748412222e-5 5.536382377613336e-5 1.1091162377852015e-5 6.285830022534356e-5 5.8212703152094036e-5 -2.9207101761130616e-5 -0.00012045689800288528 -2.6079564122483134e-5 -6.813670552219264e-6 0.00015402751159854233 -0.0001334096014034003 -2.9714638003497384e-6 -1.913050073198974e-5 1.4484578059636988e-5 1.4043293049326167e-5 -0.00016508898988831788; -4.596306553139584e-6 8.138746488839388e-5 0.0001881746284198016 0.00023030550801195204 -3.967902739532292e-5 -2.1341909814509563e-5 -5.584807149716653e-5 0.0001106298150261864 
-4.639798862626776e-5 1.8437309336150065e-5 -0.00029664920293726027 5.024510755902156e-5 0.0001474422460887581 6.004136139381444e-6 7.824609201634303e-5 -0.0002427447325317189 3.869522060995223e-6 2.341476283618249e-5 6.510158709716052e-5 9.65758808888495e-5 -8.334728772751987e-5 -8.185207843780518e-5 -2.7925383619731292e-5 -2.642797517182771e-5 -9.910939115798101e-5 -7.347912287514191e-6 -5.565102128457511e-6 5.600881195277907e-6 2.9896642445237376e-5 1.890186103992164e-5 -8.897525549400598e-5 -8.565923053538427e-6], bias = [-0.09112734347581863, 0.08320140093564987]))Visualizing the Results
Let us now plot the loss over the training iterations.
begin
    # Loss history recorded by `callback`: one point per optimizer iteration,
    # drawn both as a line and as circular markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses;
        strokewidth=2, markersize=12, marker=:circle)
    fig
end
# Finally, let us visualize the results.
# Re-integrate the model with the optimized parameters `res.u` using fixed-step RK4,
# sampled at `tsteps`, then keep the first output of `compute_waveform`
# (presumably the waveform itself — see its definition earlier in the page).
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt,
    adaptive=false) |> Array
waveform_nn_trained = compute_waveform(dt_data, soln_nn, mass_ratio,
    ode_model_params) |> first
begin
    # Overlay the data, the untrained network's prediction and the trained network's
    # prediction; every series is drawn as a line plus scatter markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    # Pair each line with its markers so each legend entry shows both glyphs.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    fig
end
# Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
# Print GPU toolchain details only when MLDataDevices and the matching backend
# package are loaded and the device reports itself as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.