Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning.
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
Printf, Random, SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body (relative) position vector, and split it between the two bodies according to their masses.
"""
    one2two(path, m₁, m₂)

Split a one-body (relative) `path` into the positions of two bodies with masses `m₁` and
`m₂`. Returns `(r₁, r₂)` with `r₁ = m₂ / M .* path` and `r₂ = -m₁ / M .* path`, where
`M = m₁ + m₂`, so that `m₁ * r₁ + m₂ * r₂ == 0` (center of mass at the origin).
"""
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in `(χ, ϕ)` form into Cartesian orbit coordinates.

`soln` must have either

- 2 rows `(χ, ϕ)`, in which case `model_params = (p, M, e)` supplies constant `p` and `e`
  (`M` is not used here), or
- 4 rows `(χ, ϕ, p, e)`, with `p` and `e` given per sample.

The radius is `r = p / (1 + e cos χ)`. The result is a `2 × n` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Check for `nothing` explicitly: `length(nothing)` would otherwise throw an opaque
        # MethodError instead of this message.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly sampled series `v` with sample spacing `dt`.

Central differences are used in the interior and second-order one-sided stencils at both
endpoints. Requires `length(v) ≥ 3`; the result has the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    # forward one-sided stencil at the first sample
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # central differences for the interior samples
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    # backward one-sided stencil at the last sample
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly sampled series `v` with sample spacing `dt`.

Central second differences are used in the interior and second-order one-sided stencils
at both endpoints. Requires `length(v) ≥ 4`; the result has the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    # forward one-sided stencil at the first sample
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # central second differences for the interior samples
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    # backward one-sided stencil at the last sample
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Mass-weighted moment of the planar `orbit` (row 1 is `x`, row 2 is `y`).

- For `component` `(1, 1)` this is `x²` minus one third of `x² + y²`.
- For `component` `(2, 2)` this is `y²` minus one third of `x² + y²`.
- Any other `component` gives the off-diagonal product `x * y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    sq_x = xs .^ 2
    sq_y = ys .^ 2
    third_of_trace(tr) = tr ./ 3
    tr = sq_x .+ sq_y
    entry = if component[1] == 1 && component[2] == 1
        sq_x .- third_of_trace(tr)
    elseif component[1] == 2 && component[2] == 2
        sq_y .- third_of_trace(tr)
    else
        xs .* ys
    end
    return mass .* entry
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (sample spacing `dt`) of the `component` entry of the
moment tensor returned by `orbit2tensor`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(1, 1)`, `(1, 2)` and `(2, 2)` components from `h_22_quadrupole_components`,
in the order `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

(2,2)-mode strain `(h₊, hₓ)` for a single body of unit mass following `orbit`.

The strains are `h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12`. The constants are
built in the floating-point type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    q11, q12, q22 = h_22_quadrupole(dt, orbit)
    norm_factor = sqrt(T(π) / 5)
    h_plus = norm_factor * (q11 - q22)
    h_cross = -norm_factor * (T(2) * q12)
    return h_plus, h_cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Component-wise sum `(h11, h12, h22)` of the quadrupole terms of two bodies.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    a11, a12, a22 = h_22_quadrupole(dt, orbit1, mass1)
    b11, b12, b22 = h_22_quadrupole(dt, orbit2, mass2)
    return a11 + b11, a12 + b12, a22 + b22
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

(2,2)-mode strain `(h₊, hₓ)` from the orbits of BH 1 (mass `mass1`) and BH 2
(mass `mass2`). The two masses must sum to one, within `1e-12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1e-12 "Masses do not sum to unity"
    q11, q12, q22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    norm_factor = sqrt(T(π) / 5)
    h_plus = norm_factor * (q11 - q22)
    h_cross = -norm_factor * (T(2) * q12)
    return h_plus, h_cross
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Convert the ODE solution `soln` into the (2,2)-mode strain `(h₊, hₓ)` with sample spacing
`dt`. `model_params` is forwarded to `soln2orbit`.

`mass_ratio` must lie in `[0, 1]`:

- A ratio of `0` uses the one-body (test-particle) strain.
- Otherwise the orbit is split into two bodies with `m₂ = 1 / (1 + mass_ratio)` and
  `m₁ = mass_ratio * m₂`, which sum to one up to rounding.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio≤1 "mass_ratio must be <= 1"
    @assert mass_ratio≥0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        return h_22_strain_one_body(dt, orbit)
    end
end
# Simulating the True Model
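`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using
$$u[1] = \chi, \qquad u[2] = \phi,$$
$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^{2}\sqrt{(p-2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2}\sqrt{(p-2)^2 - 4e^2}},$$
where $p$, $M$, and $e$ are constant model parameters.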
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = (χ, ϕ)` and constants
`(p, M, e)`. Returns `[χ̇, ϕ̇]`. Neither `ϕ` nor `t` enters the right-hand side.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)
    radial_factor = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root_term = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = radial_factor * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * root_term)
    ϕ̇ = radial_factor / (M * (p^(3 / 2)) * root_term)
    return [χ̇, ϕ̇]
end
# mass_ratio = 0 selects the one-body (test-particle) waveform in `compute_waveform`.
mass_ratio = 0.0
# Initial conditions for (χ, ϕ).
u0 = Float64[π, 0.0]
datasize = 250
# Time span for the GW waveform and the times at which it is sampled.
tspan = (0.0f0, 6.0f4)
tsteps = range(first(tspan), last(tspan); length=datasize)
# Spacing between consecutive waveform samples.
dt_data = tsteps[2] - tsteps[1]
# Fixed step size for the non-adaptive RK4 integrator.
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); adaptive=false, dt, saveat=tsteps))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    # Plot the reference waveform as a line with its sample points overlaid.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (the angle χ, i.e. `first(u)`) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but in case you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl: all weights are drawn from a truncated normal distribution with standard deviation `1e-4`.
# The first layer applies `cos` elementwise to the input. The hidden layers use `cos`
# activations. Every weight matrix is drawn from a truncated normal with std 1e-4.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# `Xoshiro()` is unseeded; pass a seed (e.g. `Xoshiro(0)`) for reproducible initialization.
ps, st = Lux.setup(Xoshiro(), nn)
# Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we
# need Float64.
# Promote the parameters to Float64 to match the Float64 ODE state `u0`.
const params = ComponentArray{Float64}(ps)
# Bundle the network with its state so it can be called as `nn_model(x, θ)`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
# Now we define a system of ODEs which describes the motion of a point-like particle with
# Newtonian physics, using
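$$u[1] = \chi, \qquad u[2] = \phi,$$
$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$
where $p$, $M$, and $e$ are the constants stored in `ode_model_params`, and $\mathcal{N}_1$, $\mathcal{N}_2$ are the two outputs of the neural network.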
"""
    ODE_model(u, nn_params, t)

Right-hand side of the Newtonian orbit model with neural-network corrections, for the state
`u = (χ, ϕ)`. The constants `p`, `M` and `e` come from the global `ode_model_params`.

`nn_model` is evaluated at `χ` with parameters `nn_params`. Its two outputs, each plus one,
multiply the Newtonian rate `(1 + e cos χ)² / (M p^(3/2))` to give `χ̇` and `ϕ̇`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example `st` holds no state (every layer's entry is an empty NamedTuple),
    # so it is not threaded through here. In general, use `st` to carry the state of the
    # neural network.
    y = 1 .+ nn_model([χ], nn_params)
    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))
    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]
    return [χ̇, ϕ̇]
end
# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
# Same initial conditions and time span as the true model, with the neural-network
# right-hand side and the untrained parameters.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); p=params, u0, adaptive=false, dt, saveat=tsteps))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    # Overlay the reference waveform and the untrained neural-ODE waveform.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the waveform predicted with neural-network parameters `θ` and
the reference `waveform`. The prediction is the first element returned by
`compute_waveform`. Returns `(loss_value, pred_waveform)`, matching the
`callback(θ, l, pred_waveform)` signature.
"""
function loss(θ)
    # NOTE(review): `prob_nn`, `u0`, `tsteps`, `dt`, `dt_data`, `mass_ratio` and `waveform`
    # are non-`const` globals; declaring them `const` would make this function type-stable.
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    # Lux losses take (prediction, target); MSE is symmetric, so the value is unchanged.
    return mseloss(pred_waveform, waveform), pred_waveform
end
# Warmup the loss function
# Evaluate the loss once before training so its compilation cost is paid up front.
loss(params)
# Now let us define a callback function to store the loss over time
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Append the loss `l` to `losses`, print a progress line, and return `false` so optimization
continues. `θ` and `pred_waveform` are accepted but not used.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    @printf "Training %10s Iteration: %5d %10s Loss: %.10f\n" "" length(losses) "" l
    return false
end
# Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.
# Gradients of the loss are taken with Zygote. The objective ignores its second
# (hyperparameter) argument and evaluates `loss` at the current parameters only.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step norm and a backtracking line search; `callback`
# records the loss every time the optimizer invokes it.
res = Optimization.solve(optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [3.409750570426695e-5; -5.335412788554095e-5; -0.00023229134967550635; 2.6896035706158727e-5; 6.476599810412154e-5; 6.789436156395823e-5; 0.0001762003666954115; 6.102080442360602e-5; 0.00010646814189385623; 0.00021610218391288072; 4.502246520132758e-5; 9.842102736001834e-5; 3.6622819607146084e-5; -3.9151516830315813e-5; -4.410462861414999e-5; -1.6729940398363397e-5; -8.749419066589326e-5; -2.682280319277197e-5; -0.0001094052495318465; -0.00014563721197191626; 6.830096390331164e-5; 3.95687748095952e-5; -0.0001073439052561298; 9.670897998148575e-5; 8.210131636587903e-5; 3.821753489319235e-5; 0.00011002541577909142; 0.00012195475574117154; 5.61381011721096e-6; -2.107597902067937e-5; -0.00013300184218678623; 8.509586223226506e-6;;], bias = [-0.8554672002792358, 0.43835997581481934, 0.18090736865997314, 0.17270290851593018, 0.7747138738632202, -0.537463903427124, 0.7339503765106201, -0.6907318830490112, -0.9903888702392578, -0.6358287334442139, -0.9414416551589966, -0.13805973529815674, -0.14056086540222168, -0.08396017551422119, -0.12377333641052246, -0.2723672389984131, 0.4750349521636963, 0.8160997629165649, -0.5674476623535156, 0.12224745750427246, 0.7362548112869263, 0.9810594320297241, -0.7524651288986206, -0.9390366077423096, 0.1979461908340454, 0.018231630325317383, -0.8508137464523315, -0.9459489583969116, -0.0008701086044311523, -0.692562460899353, 0.19172227382659912, 0.3651496171951294]), layer_3 = (weight = [2.2019934476702474e-5 9.105469507630914e-5 2.672578557394445e-5 0.00011013531184289604 4.211144187138416e-5 5.407473508967087e-5 3.087161167059094e-5 -1.3899451005272567e-5 0.00016610106104053557 0.00010314327664673328 2.4404229407082312e-5 8.36261642689351e-6 0.00011248399823671207 -2.5824376734817633e-6 -1.1004563020833302e-5 -7.22534314263612e-5 2.842518915713299e-5 3.1412342650583014e-5 3.6345958505989984e-5 0.0002244042989332229 -0.00010714833479141816 -3.981481131631881e-5 
-0.0001521389203844592 0.0001781196187948808 9.190930722979829e-5 -7.363268196058925e-6 2.9204233214841224e-5 -6.39914142084308e-5 5.9155147027922794e-5 7.290606299648061e-5 3.936544089810923e-5 2.496785600669682e-5; -3.631928120739758e-5 -2.0986639356124215e-5 8.745924424147233e-5 7.878740143496543e-5 0.0001431980199413374 0.00011024868581444025 -8.994626114144921e-5 -2.0634108295780607e-5 -4.9877187848323956e-5 -1.3193887298257323e-6 -8.412228635279462e-5 -0.0001907964760903269 -4.862644345848821e-5 -0.0002169862564187497 -0.00015282419917639345 -7.30842657503672e-5 -7.155204366426915e-5 -0.00012661860091611743 -7.700318383285776e-5 8.508758764946833e-5 5.8536748838378116e-5 4.136842107982375e-5 9.481730558036361e-6 9.315815077570733e-6 1.3012886483920738e-5 -8.585360046708956e-5 0.00011791313590947539 9.95405571302399e-5 -0.0002253167040180415 3.1915544695948483e-6 -0.00014858407666906714 -1.7624666725168936e-5; 1.3660184777108952e-5 5.960712587693706e-5 2.894238059525378e-5 2.5221075702575035e-5 5.174574471311644e-5 0.00010662020940799266 8.439322846243158e-5 -1.8862361685023643e-5 -4.3768177420133725e-5 1.1804941095761023e-5 3.637802001321688e-5 1.0617779480526224e-5 -5.779686398454942e-5 -3.227528213756159e-5 -0.00022527061810251325 0.0001294924586545676 4.300949876778759e-5 5.6790318922139704e-5 -5.6455250160070136e-5 9.609929838916287e-5 2.0569172193063423e-5 6.456080154748634e-5 6.73849499435164e-5 -6.91259847371839e-5 0.00010254904191242531 2.938227407867089e-5 1.9133376554236747e-5 -6.114810275903437e-6 9.075788693735376e-5 -0.00016537742340005934 -2.444118945277296e-5 0.00010570185258984566; 5.119317938806489e-5 7.461401401087642e-5 -0.00022950781567487866 0.00015979072486516088 -0.00015588448150083423 -2.759972085186746e-5 5.4033345804782584e-5 -7.009402179392055e-5 0.0001692324149189517 -0.00019113793678116053 1.1821684893220663e-5 0.0001601384428795427 0.00015681746299378574 -0.00017785119416657835 2.422122452117037e-5 -0.00021886860486119986 
5.923616481595673e-5 -8.459040691377595e-5 -3.581713099265471e-5 0.0001898818591143936 -3.6216777516528964e-5 0.00017793814186006784 -8.984872692963108e-6 7.182870467659086e-5 5.814036921947263e-5 -0.00012459892604965717 0.00016681945999152958 -5.725595110561699e-5 -1.5240711945807561e-5 -6.921707245055586e-5 0.00013018754543736577 -0.0002219159359810874; 1.1235368219786324e-5 1.3497149666363839e-5 9.251108713215217e-5 -4.674421506933868e-5 0.0001173194104922004 0.0002010657626669854 -5.490717740030959e-5 1.833624264691025e-5 8.49107964313589e-5 4.769928182213334e-6 -2.8853319236077368e-5 -6.0636313719442114e-5 -2.9623062800965272e-5 7.029977859929204e-5 7.327495404751971e-5 4.299528882256709e-5 -8.130276910378598e-6 -2.446531470923219e-5 7.23186312825419e-5 -0.00015931620146147907 9.56386502366513e-5 0.00010066628601634875 3.894710971508175e-5 -6.1212238506414e-5 -3.599880074034445e-5 -3.7958267057547346e-5 -5.402654642239213e-5 -7.773507240926847e-5 -0.00012637945474125445 7.280933641595766e-5 2.0582434444804676e-5 0.00023414351744577289; -3.557337913662195e-5 2.7564410629565828e-5 -0.00010337385174352676 0.00010487824329175055 0.00013306326582096517 -2.012031018239213e-6 5.796192635898478e-5 -2.8401465897331946e-5 -4.255977546563372e-5 8.563348501411383e-7 -6.821587885497138e-5 -7.619395182700828e-5 3.0819650419289246e-5 -4.792854451807216e-5 -8.463047561235726e-5 4.951489609084092e-5 4.81270435557235e-5 -0.0001120308879762888 7.230371556943282e-5 -3.549687608028762e-5 -4.965796324540861e-5 -6.627291440963745e-5 -2.438108276692219e-5 3.882192913806648e-7 -3.028658466064371e-5 -6.911711534485221e-5 0.00015253774472512305 6.790864426875487e-5 -8.01471687736921e-5 -0.00012482634338084608 -0.00017997762188315392 0.00013130062143318355; 0.00010582964750938118 0.000220138332224451 0.00010590407327981666 -2.1493453459697776e-5 7.832588016754016e-5 -5.887639781576581e-5 8.346209506271407e-5 3.481541352812201e-5 -0.00013060007768217474 -3.34938958985731e-5 
3.627718615462072e-5 0.00014513509813696146 3.1081704946700484e-5 -1.7285965441260487e-5 -2.7875146770384163e-5 -0.00012055417755618691 -7.62518320698291e-5 -6.581398338312283e-5 0.00012290834274608642 4.293251913622953e-5 0.00015772884944453835 0.00011260274186497554 -5.641372626996599e-5 0.00013965749531053007 0.000113714013423305 -0.00010072881559608504 6.740591197740287e-5 -0.0001860295596998185 9.821791900321841e-5 -6.462408055085689e-5 1.3134823348082136e-5 9.052437235368416e-5; 0.00014229639782570302 3.785945591516793e-5 -4.123769758734852e-5 -9.11457245820202e-5 -4.1365798097103834e-5 5.362163210520521e-5 -3.393009683350101e-5 3.5047934943577275e-5 1.8136273638447165e-6 -7.302139420062304e-5 8.889237506082281e-5 0.00024709413992241025 -0.00014537987590301782 0.00013519827916752547 -0.0001159932289738208 6.340861727949232e-5 -1.4834025023446884e-5 9.397284156875685e-5 9.060985030373558e-5 1.9739949493668973e-5 -4.889204501523636e-5 8.783928205957636e-5 2.9862963856430724e-5 3.984573413617909e-5 -3.414289312786423e-5 -7.437027670675889e-5 0.0001221483835252002 -1.0997738172591198e-5 7.017819007160142e-5 -0.00010020639456342906 2.3358577891485766e-5 6.236899935174733e-6; 4.64577206003014e-5 6.606425813515671e-6 6.409629713743925e-5 -4.756190537591465e-5 1.7183832824230194e-5 -4.9554168072063476e-5 1.0981941159116104e-5 -8.405133849009871e-5 0.00010131720773642883 8.581134170526639e-5 5.402123406383907e-6 4.038698898511939e-5 -5.432782472780673e-6 6.883496098453179e-5 6.614168523810804e-5 -0.0001296888804063201 -9.541193139739335e-5 1.0767668754851911e-5 8.868139047990553e-6 -0.00011689990060403943 0.00010504195233806968 -2.429191408737097e-5 8.92857315193396e-6 -0.00013418438902590424 -1.8854545487556607e-5 -2.6927686121780425e-5 2.45901828748174e-5 2.07070515898522e-5 5.44172253285069e-5 0.00019555012113414705 3.735105929081328e-5 -7.583571277791634e-5; -9.570881957188249e-6 0.00010471379209775478 7.713494483141403e-7 0.00023507488367613405 
-0.000298128230497241 -0.00012194034934509546 -6.315030623227358e-5 -0.0002217157743871212 0.00024056507390923798 0.00011514026118675247 8.989515481516719e-5 -7.951035513542593e-5 3.929326339857653e-5 -2.6734100174508058e-5 6.635396857745945e-5 -5.767941547674127e-5 -0.00011449221346992999 -0.00010660111729521304 0.00010526114783715457 -0.0001008436011034064 -0.00023731838155072182 3.7925019569229335e-5 -0.0001880197669379413 0.00010788175131892785 -6.604657392017543e-5 -6.473218672908843e-5 -5.0701175496215e-5 -0.00010545324767008424 0.00012679708015639335 4.844202339882031e-5 5.652969775837846e-5 8.838227950036526e-5; -2.1435353119159117e-5 -0.00010857757297344506 8.72379241627641e-5 2.178138674935326e-5 -0.00011873425683006644 1.3891183698433451e-5 -0.0001686171890469268 -0.0002428300940664485 -2.774911990854889e-5 -0.00018039006681647152 -0.00012225729005876929 5.791188232251443e-5 2.9661157896043733e-5 3.4046726796077564e-5 0.00011700436152750626 -0.00015084077313076705 6.674859469057992e-5 8.944400906329975e-5 0.00013883927022106946 -0.0001700252469163388 -7.893057772889733e-5 -0.00014329487748909742 0.0001439576008124277 4.365279892226681e-5 0.00013312466035131365 4.9454858526587486e-5 -1.2009610145469196e-5 -0.00011458102380856872 9.358418174088001e-5 -8.82563108461909e-5 9.704816875455435e-6 0.00024231270072050393; -2.5632354663684964e-5 -7.205962901934981e-5 -4.145114871789701e-6 8.561345748603344e-5 -0.00012784426508005708 -5.029623935115524e-6 7.240428385557607e-5 -2.7754229449783452e-5 -0.00015231134602800012 -0.0001937497581820935 1.1621168596320786e-5 -6.149663386167958e-5 -9.140653855865821e-5 -4.679382618633099e-5 -5.278977187117562e-5 3.919091250281781e-5 9.451057121623307e-5 -3.329732862766832e-5 9.497339488007128e-5 -5.959437839919701e-5 -4.589289164869115e-5 6.079220838728361e-5 -6.26591790933162e-5 0.00013474415754899383 -6.913077231729403e-5 -0.0001392098201904446 7.389216534647858e-6 7.96801905380562e-5 0.00012063571193721145 
2.2644035198027268e-5 0.00020402736845426261 0.0001663088332861662; 0.00017571247008163482 -0.00011754441220546141 9.47530770645244e-6 -5.133147624292178e-6 -0.00014763386570848525 5.277132731862366e-5 -0.00011003890540450811 2.4373508495045826e-5 3.6557848943630233e-5 2.3612670702277683e-5 7.36299843993038e-5 -1.1933359928661957e-5 -7.894892041804269e-5 7.443757931469008e-5 -1.5575551515212283e-5 4.028984403703362e-5 1.2749095731123816e-5 -9.327548468718305e-5 4.239876216161065e-5 -6.99628799338825e-5 -1.895692548714578e-5 -1.1643972356978338e-5 2.0381643480504863e-5 6.480221054516733e-5 -8.077422535279766e-6 -7.033826477709226e-6 0.00011463044211268425 9.971537656383589e-5 -6.238808418856934e-5 -3.15888537443243e-5 -0.00020152523939032108 0.00013340468285605311; 0.0001083228926290758 0.0001360651949653402 -7.15825444785878e-5 -0.00015130820975173265 -2.417677933408413e-5 0.00012202212383272126 2.9229979190859012e-5 -3.1851708627073094e-5 -3.3238520700251684e-5 -6.866923740744824e-6 -0.00016792122914921492 1.0131540875590872e-5 0.00019030591647606343 0.00020995100203435868 1.8427459508529864e-5 5.809492358821444e-5 -0.00013603751722257584 9.613665315555409e-5 3.5310335078975186e-5 0.00012460372818168253 6.034756006556563e-5 -1.4836867194389924e-5 -9.724632946017664e-6 1.2549166058306582e-5 0.0001339084847131744 -3.129617834929377e-5 9.00206869118847e-5 6.968469097046182e-5 -1.1891258509422187e-5 3.9419708627974615e-5 0.00014710085815750062 7.194032514235005e-5; -0.00012792303459718823 -8.342724322574213e-5 -9.058794967131689e-5 -0.00013350590597838163 -0.00012840097770094872 -5.501897612703033e-5 -6.58335120533593e-5 0.00022190222807694227 8.32714340504026e-6 7.824950444046408e-5 -1.127210180129623e-5 0.0001286912738578394 7.400625327136368e-5 0.0001412305427948013 0.0001462146028643474 0.00019910714763682336 -0.00012429573689587414 -9.184675582218915e-5 -2.122218211297877e-5 0.0001629636826692149 1.309226718149148e-5 -0.00012941673048771918 1.6945066818152554e-5 
3.77378637494985e-5 3.260942321503535e-5 -3.517978257150389e-5 -6.579084583790973e-6 7.769471267238259e-5 0.00014544374425895512 -0.00012378013343550265 -1.2224697456986178e-5 -4.6847303565300535e-6; 4.010573684354313e-5 7.890792039688677e-5 5.618640352622606e-5 -1.9932276700274087e-5 -8.909989992389455e-5 7.169597665779293e-5 3.89471351809334e-5 5.7807163102552295e-5 2.996169132529758e-5 -0.00014686833310406655 -4.5437966036843136e-5 -1.7393209418514743e-5 -0.00016153082833625376 5.567002153838985e-5 -6.869583739899099e-5 -3.500157981761731e-5 -1.766905188560486e-5 1.5762214388814755e-5 -5.4545966122532263e-5 -1.975725535885431e-5 -0.00012160993355792016 0.00014615344116464257 8.923574205255136e-5 -0.00011275747237959877 -7.542790262959898e-5 -8.485709258820862e-5 -5.24333381690667e-6 2.7070629585068673e-5 3.9881026168586686e-5 -2.096473690471612e-5 -0.000226080184802413 4.6635079343104735e-5; -0.0001448978582629934 -4.838741006096825e-5 8.906100265448913e-5 -2.747172948147636e-5 4.428270767675713e-5 0.00027671054704114795 -0.0001145769128925167 1.1188249118276872e-5 9.584998042555526e-5 0.00010472063149791211 -8.005667041288689e-5 -7.829991955077276e-5 3.308649320388213e-5 -2.0379604393383488e-5 -7.816457946319133e-5 5.399430665420368e-5 4.083643580088392e-5 -0.0001441335625713691 0.00019223954586777836 0.00014163274317979813 -2.366349690419156e-5 3.29047761624679e-5 -4.464088488020934e-5 -2.329999091443824e-7 -0.00019072671420872211 5.088467514724471e-5 2.7633976060315035e-5 -6.631865107920021e-5 7.058769551804289e-5 0.0001637821551412344 8.049670441323542e-7 2.560430220910348e-5; 0.00016200472600758076 -4.257268756191479e-6 0.0001336562418146059 -3.642923184088431e-5 0.00017047094297595322 -1.2699376384261996e-5 0.00012303245603106916 -3.825818566838279e-6 0.00011523787543410435 0.0001726273912936449 -0.00010101302177645266 -2.6358358809375204e-5 -7.70268070482416e-6 -0.0001189827817142941 9.001267244457267e-6 -2.3497952497564256e-5 0.00016548218263778836 
1.311809683102183e-5 3.396844840608537e-5 3.6071483918931335e-5 4.144014383200556e-5 -3.466258567641489e-5 5.67444567423081e-6 1.8615153749124147e-5 0.00017914573254529387 1.3591044989880174e-5 5.681921902578324e-5 -3.347163897160499e-7 0.00010047436080640182 4.7402296331711113e-5 6.0360162024153396e-5 -0.00011285785876680166; -5.866911669727415e-5 7.683583680773154e-5 6.54384057270363e-5 -2.0907438738504425e-5 7.765570626361296e-5 -1.6613377738394774e-5 5.5328902817564085e-5 -0.00016739014245104045 -8.380937651963905e-5 -1.640985101403203e-5 3.577027382561937e-5 -4.7975820052670315e-5 2.881062937376555e-5 -1.5630590496584773e-5 -3.488231232040562e-5 4.433781941770576e-5 0.00011308935791021213 0.0001208884859806858 0.00022815623378846794 -0.0001092886523110792 1.1260568498983048e-5 -1.2580669135786593e-5 -9.96986564132385e-5 0.00014072423800826073 -6.398738332791254e-5 0.00017670470697339624 -0.00016911518468987197 -3.069260128540918e-5 1.998089464905206e-5 -7.48841484892182e-5 0.00010655781807145104 0.00013555727491620928; -4.462685319595039e-5 -0.00010627717711031437 8.560364949516952e-5 1.1342577636241913e-5 -0.00013106611731927842 -0.00011905968858627602 -0.00011495400394778699 3.085965363425203e-5 0.00015525761409662664 -8.399774378631264e-5 0.00017239963926840574 3.2744865166023374e-5 -8.582754526287317e-5 9.193525329465047e-5 4.01500437874347e-5 8.061016706051305e-5 9.093012340599671e-5 1.135844286181964e-5 8.30430508358404e-5 0.00010015375301009044 3.6357157569000265e-6 1.3751867300015874e-5 -8.724162762518972e-5 9.211325232172385e-5 0.00013793438847642392 -9.044162288773805e-5 -8.656742284074426e-5 8.643633191240951e-5 -0.00010318595741409808 6.167111678223591e-6 -9.761191677171155e-7 4.561751120490953e-5; 0.00019377081480342895 0.00011719371104845777 3.931116953026503e-5 8.101115963654593e-5 1.0367314644099679e-5 3.733412449946627e-5 2.8021338948747143e-5 -9.724258416099474e-5 6.670218863291666e-5 2.029398820013739e-5 -7.114751497283578e-5 
7.368823571596295e-5 6.347742601064965e-5 -0.0003348978643771261 6.690716691082343e-5 5.585742474067956e-5 -2.4221997591666877e-5 2.191914973082021e-5 -3.985416333307512e-5 -7.719074346823618e-6 1.8616005036165006e-5 -0.00025097618345171213 -0.00013727709301747382 -2.695632792892866e-5 -3.768226451938972e-5 3.134491998935118e-5 2.3503023840021342e-5 8.89654693310149e-5 -2.089784902636893e-5 -3.925355485989712e-5 0.00010828355880221352 8.33760277600959e-5; 4.955869007972069e-5 -2.8548022328322986e-6 0.00020120151748415083 -3.644085518317297e-5 -9.22206963878125e-5 -0.00014626013580709696 1.678044100117404e-5 0.00011507803719723597 -6.297668005572632e-5 -0.00020066842262167484 -6.553863931912929e-5 -0.00011779871420003474 -0.00014656725397799164 7.71466875448823e-6 4.586006616591476e-5 -7.435405859723687e-5 -3.8495803892146796e-5 9.302836588176433e-6 0.00015242904191836715 -0.00016789681103546172 -1.7262336768908426e-5 0.0001991793978959322 -1.4479836863756645e-5 1.942857306858059e-5 8.330853597726673e-5 0.00010955090692732483 -7.541986997239292e-5 -4.742374949273653e-5 -6.35646065347828e-5 2.876654252759181e-5 -8.893696212908253e-5 -3.604584344429895e-5; -7.74514555814676e-5 -3.227229171898216e-5 -6.771850894438103e-5 9.157157364825252e-6 -1.451436855859356e-5 -8.079775579972193e-5 0.00010233813372906297 7.082497177179903e-5 1.5830184565857053e-5 0.00010741748701548204 -0.00011815877951448783 0.00012318867084104568 -0.00010529715655138716 0.00011841931700473651 6.797039532102644e-5 9.078153379959986e-5 -2.210003003710881e-5 -0.00016057805623859167 -3.705140625243075e-5 0.00014816378825344145 4.5504890294978395e-5 -0.00012947649520356208 1.153805078502046e-5 -6.18158737779595e-5 2.750215026026126e-5 -0.00010225973528577015 -0.00017631679656915367 -4.017475657747127e-5 -5.801122460979968e-5 -9.670139115769416e-5 2.5336703401990235e-6 0.00011358357005519792; 0.00011604254541452974 9.432235856365878e-6 -0.00012336924555711448 5.1392791647231206e-5 -2.9819319024682045e-5 
0.00011323454236844555 -2.8311471396591514e-5 2.0064719137735665e-5 -8.679534948896617e-5 -0.00014086645387578756 0.0001514036557637155 3.860444849124178e-5 -7.973330502863973e-5 0.00010014079452957958 6.166857929201797e-5 -9.063363540917635e-5 -7.103036477928981e-5 1.0360376109019853e-5 5.364328171708621e-5 7.409664249280468e-5 0.00010066705726785585 -4.5460845285560936e-5 -4.117118805879727e-5 2.7913949452340603e-5 7.743995229247957e-5 3.1501542252954096e-5 -4.329483999754302e-6 2.4946246412582695e-5 0.00013919288176111877 -3.7270034226821736e-5 6.808712350903079e-5 -9.9331206001807e-5; -0.00022747709590476006 -8.706080552656204e-5 2.1914403987466358e-5 -1.0295335414411966e-5 -5.1190447265980765e-5 4.519845970207825e-5 0.00030138620059005916 -0.0001221845595864579 -2.683581260498613e-5 4.578065636451356e-5 7.634813664481044e-5 6.182448686331554e-8 -3.2750417631177697e-6 0.00018880485731642693 5.617076385533437e-5 -1.9875815269188024e-5 -1.6453617718070745e-5 0.00012165548832854256 0.00014018190267961472 0.00015234501915983856 -4.430962144397199e-5 -6.969981041038409e-5 -2.7094498364022e-5 1.7838941857917234e-5 8.12805683381157e-6 -0.00012878587585873902 4.4352786062518135e-5 8.282353519462049e-5 -0.00012609281111508608 1.3783238500764128e-5 -4.222116695018485e-5 -1.710194737825077e-5; 0.00012324888666626066 -1.6746822439017706e-5 0.00021113526599947363 4.363241168903187e-5 2.8883179766125977e-5 0.00011613313108682632 9.616137685952708e-5 0.00016876187874004245 -2.6048599465866573e-5 3.4231419704155996e-5 -0.00010197562369285151 -0.00011050250759581104 -5.810648872284219e-5 -8.594321116106585e-5 2.4218901671702042e-5 3.264465703978203e-5 3.2639061828376725e-5 -4.835022264160216e-5 1.0157280485145748e-5 0.0001753669057507068 0.00011751082638511434 -7.599469972774386e-5 -0.00020620727445930243 7.447786629199982e-5 -8.509406143275555e-6 -4.0822782466420904e-5 0.00014426666893996298 0.0001096401028917171 -0.00014902454859111458 0.0002021655673161149 
0.00015462147712241858 5.587573105003685e-5; 0.00013588623551186174 0.00010871521953959018 0.00016479277110192925 6.841660797363147e-5 -4.764049663208425e-5 6.993526767473668e-5 5.651050014421344e-5 7.88576653576456e-5 0.00010505487443879247 -8.873856131685898e-5 3.632193329394795e-5 -7.598448428325355e-5 3.345215660033318e-8 -5.1688894018298015e-5 3.1835719710215926e-5 3.489258233457804e-5 2.134137503162492e-5 3.9783095417078584e-5 8.086920570349321e-5 8.655919373268262e-5 0.00013035449956078082 -0.0001173340278910473 2.8707108867820352e-5 7.500393985537812e-5 3.121802001260221e-5 -0.0002150740911019966 0.0001517662894912064 -5.692673312296392e-6 0.00013747521734330803 3.994857979705557e-5 -5.8691475715022534e-5 -2.1026062313467264e-5; -1.776800854713656e-5 2.8385475161485374e-5 -4.829340832657181e-5 -1.992896977753844e-5 0.00018538522999733686 7.28016602806747e-5 -0.0001052964580594562 -0.00010573071631370112 -9.419952402822673e-5 -5.384718315326609e-5 -0.00010413731070002541 5.050981781096198e-5 -5.631001477013342e-5 -0.00014136327081359923 1.0767562343971804e-5 6.810315881011775e-6 -2.3922626496641897e-5 3.280593955423683e-5 -0.00010329368524253368 -8.860818525135983e-6 0.0002724767255131155 0.0002085306914523244 -3.692392783705145e-5 -1.1903732229257002e-5 -0.0001412571727996692 -1.0928595656878315e-5 -4.351417737780139e-5 -9.105979552259669e-5 -2.329786184418481e-5 7.319439464481547e-5 -3.547191226971336e-5 -0.0002031548647210002; -0.00017882275278680027 0.00020139111438766122 4.145969433011487e-5 -7.622239354532212e-5 8.910217002267018e-5 -3.6916106182616204e-5 -0.00017392578592989594 0.0001990851160371676 7.390890823444352e-5 -0.00019941334903705865 4.208279733575182e-6 8.226073259720579e-5 -1.7774429579731077e-5 -0.00012787389277946204 -0.00016366480849683285 8.18990301922895e-5 -5.253302151686512e-5 -7.19156232662499e-5 9.163164941128343e-5 -1.6263284123851918e-5 0.00014502814156003296 -6.751808541594073e-5 1.2916156265418977e-5 3.6204895877745e-5 
-6.540441245306283e-5 -4.5406432036543265e-5 2.0972174752387218e-5 1.3095974281895906e-5 6.232285522855818e-5 -1.9974602764705196e-5 9.921007585944608e-5 -8.619159780209884e-5; 0.00016936137399170548 3.6560748412739486e-5 -4.1919283830793574e-5 -0.00011885561252711341 9.655206667957827e-5 -0.00011219928273931146 9.556937584420666e-6 0.00015204166993498802 4.756220369017683e-5 1.0476418538019061e-5 -4.701078069047071e-5 -2.5948735128622502e-5 -0.00016768412024248391 -3.078899680986069e-5 -0.00011574975360417739 -2.0424206013558432e-5 -2.9971211915835738e-5 -7.571733294753358e-5 6.058250801288523e-5 -0.00017293001292273402 4.3168958654860035e-5 -6.271675374591723e-5 -0.00015405833255499601 -0.00012047802738379687 0.00012028223864035681 0.00012838786642532796 -3.796735472860746e-5 -8.245099888881668e-5 0.00012896183761768043 1.3643792726725223e-6 6.431055953726172e-5 1.63357599376468e-5; 0.00010003284842241555 3.727208968484774e-5 -9.138933091890067e-5 -5.2181312639731914e-5 8.31035795272328e-5 6.492868124041706e-5 4.583011832437478e-5 -4.218614776618779e-5 -0.00016519869677722454 -9.550801769364625e-5 8.766805694904178e-5 7.315658422157867e-6 0.00034057506127282977 6.907164333824767e-6 -4.06874023610726e-5 1.0072630175272934e-5 7.174008715082891e-6 -6.816286622779444e-5 -4.639694816432893e-5 0.00018629655824042857 -5.191858144826256e-5 9.014063107315451e-5 0.0001114734768634662 8.917761442717165e-5 0.0001618623355170712 5.4132466175360605e-5 -8.335194434039295e-5 3.566276063793339e-5 -6.852857768535614e-5 -4.1173570934915915e-5 -0.00013462270726449788 -5.967691686237231e-5; 0.00010684567678254098 0.0001348928053630516 -5.5352218623738736e-5 5.000708551960997e-5 5.6016549933701754e-5 -0.0001847409876063466 0.00011968142644036561 0.00014316012675408274 7.836448639864102e-5 -7.245381129905581e-5 -3.982964699389413e-6 0.00010696728713810444 3.4669501474127173e-5 -5.936117668170482e-5 5.6604385463288054e-5 -7.14548586984165e-5 -0.00016522465739399195 2.422839679638855e-5 
-3.837033364106901e-5 -1.8790186004480347e-5 0.00016033575229812413 4.7772660764167085e-5 2.2544363673659973e-5 -8.133015944622457e-5 2.3503174816141836e-5 -4.137967152928468e-6 -6.000340363243595e-5 -0.0001319613220402971 -1.4418255886994302e-5 5.673018313245848e-5 0.0001728900388116017 4.681338032241911e-5], bias = [-0.023594269528985023, -0.05203936621546745, -0.052517060190439224, 0.07375571876764297, -0.1265939325094223, 0.05429525673389435, 0.12318826466798782, 0.02931918203830719, -0.08083023875951767, 0.05663158372044563, 0.1044766828417778, -0.021583063527941704, 0.025179455056786537, 0.12438694387674332, 0.09733578562736511, 0.0443447008728981, -0.04841933399438858, -0.1316634714603424, -0.04982583597302437, 0.13987140357494354, 0.007232251577079296, 0.050686705857515335, 0.14078636467456818, -0.1650865077972412, 0.08078638464212418, -0.1699635535478592, 0.08915799111127853, -0.09229101985692978, -0.07631192356348038, 0.056342896074056625, -0.09152786433696747, -0.031157774850726128]), layer_4 = (weight = [0.0001559946540510282 3.311972977826372e-5 -6.510344974230975e-5 6.20817590970546e-5 -8.184646867448464e-5 0.00018379103858023882 -1.2308260920690373e-5 2.1719337382819504e-5 -4.432384594110772e-5 0.00012864801101386547 -5.214899283600971e-5 -0.00013724021846428514 2.071417839033529e-6 4.5390559535007924e-5 -3.2446214390802197e-6 6.442248559324071e-5 -3.685548290377483e-5 0.00015111188986338675 4.548730430542491e-6 9.539115126244724e-5 0.00018201388593297452 0.00010909794218605384 -6.182838842505589e-5 3.27629350067582e-5 -3.246147753088735e-5 8.307005191454664e-5 -0.00015180303307715803 -7.588856533402577e-5 0.0002587763301562518 -7.149972952902317e-5 -6.582278001587838e-5 -3.664331188701908e-6; 2.783242416626308e-5 1.2609854820766486e-5 -0.00018753566837403923 8.129147317958996e-5 6.262536771828309e-5 5.237340155872516e-5 2.7905700335395522e-5 -0.00012650468852370977 -0.00017288667731918395 3.900047158822417e-5 -0.00013039055920671672 
-7.566464773844928e-5 0.0001980259985430166 1.1299887773930095e-5 6.215372650331119e-6 4.175553112872876e-5 -0.00012470240471884608 5.3101710363989696e-5 -8.300112676806748e-5 -1.5673662346671335e-5 6.919493898749352e-5 -0.0001350373640889302 0.0001492159499321133 -2.710381340875756e-5 -2.3011971279629506e-5 6.96070128469728e-5 -0.0001946683623827994 -6.756717630196363e-5 3.5857679904438555e-5 9.058965952135623e-5 3.369592013768852e-5 7.203921268228441e-5], bias = [0.125216543674469, -0.03470126911997795]))Visualizing the Results
Let us now plot the loss over time
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    # Loss history against its iteration index: a thick translucent line with a
    # circular marker at every recorded value.
    lines!(ax, eachindex(losses), losses; linewidth=4, alpha=0.75)
    scatter!(ax, eachindex(losses), losses;
        marker=:circle, markersize=12, strokewidth=2)
    fig
end
Finally let us visualize the results
# Re-integrate the model with the optimized parameters `res.u` using fixed-step RK4,
# saving the solution at `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4();
    u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))
# Keep only the first element returned by `compute_waveform`
# (presumably the waveform itself — confirm against its definition).
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Every series is drawn as a translucent line plus circular markers; the line and
    # marker handles are paired below so each legend entry shows both.

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Prediction before training.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Prediction after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    axislegend(ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
# GPU toolchain details are printed only when MLDataDevices and the corresponding
# backend package are loaded in this session and the device reports as functional.
# `&&` short-circuits, so the device check never runs for unloaded packages.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.