Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
Printf, Random, SciMLSensitivity
using CairoMakieDefine some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.
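We need a very crude 2-body path. Assume the 1-body motion is the Newtonian relative position vector $r = r_1 - r_2$, and recover the individual positions as $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, with $M = m_1 + m_2$.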
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) path into the positions of the two bodies about their
center of mass: body 1 sits at `m₂ / (m₁ + m₂)` times `path`, body 2 at
`-m₁ / (m₁ + m₂)` times `path`.

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    frac₁ = m₂ / total
    frac₂ = -m₁ / total
    return frac₁ .* path, frac₂ .* path
end
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution to a Cartesian orbit.

`soln` must have either 2 or 4 rows (one column per saved time):
- 2 rows `(χ, ϕ)`: the constants `p` and `e` are taken from
  `model_params = (p, M, e)`, which is then required (`M` is not used here).
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` may vary along the trajectory.

The radius is `r = p / (1 + e cos χ)`; the result is a `2 × N` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"
    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Check for `nothing` explicitly: `length(nothing)` would otherwise throw an
        # unhelpful MethodError instead of this message.
        @assert(!isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly sampled series `v` with sample spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2dt`; the first and
last points use second-order one-sided stencils. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, interior, right_edge) / dt
end
# This function uses second-order one-sided difference stencils at the endpoints; see
# https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly sampled series `v` with sample spacing `dt`.

Interior points use the three-point stencil `v[i-1] - 2v[i] + v[i+1]`; each endpoint
uses the four-point one-sided stencil `2v₀ - 5v₁ + 4v₂ - v₃` taken inward from that
end. All values are divided by `dt^2`. Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    first_pt = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    last_pt = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(first_pt, interior, last_pt) / (dt^2)
end
# Now we define a function to compute the trace-free moment tensor from the orbit
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor of a planar orbit as a time series.
The first two rows of `orbit` are taken as `x` and `y`.

- `component == (1, 1)`: `mass * (x² - (x² + y²)/3)`
- `component == (2, 2)`: `mass * (y² - (x² + y²)/3)`
- any other index pair: the off-diagonal term `mass * x * y`
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    Ixx = xs .^ 2
    Iyy = ys .^ 2
    trace = Ixx .+ Iyy
    if component[1] == 1 && component[2] == 1
        return mass .* (Ixx .- trace ./ 3)
    elseif component[1] == 2 && component[2] == 2
        return mass .* (Iyy .- trace ./ 3)
    end
    # Every remaining index pair falls through to the off-diagonal term.
    return mass .* (xs .* ys)
end
# Return twice the second time derivative (sample spacing `dt`) of the `component` entry
# of the trace-free moment tensor of `orbit`, weighted by `mass`.
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
# Return the `(h11, h12, h22)` components computed by `h_22_quadrupole_components` for
# `orbit` (sample spacing `dt`, body mass `mass`).
function h_22_quadrupole(dt, orbit, mass=1)
    component(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

(2,2)-mode strain for a single body of unit mass following `orbit`.

Returns `(√(π/5) (h11 - h22), -√(π/5) · 2 h12)`, with the constants built in the type
`T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    scale = √(T(π) / 5)
    h_plus = scale * (h11 - h22)
    h_cross = -scale * (T(2) * h12)
    return h_plus, h_cross
end
# Sum the `(h11, h12, h22)` components of two bodies, each with its own orbit and mass.
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    a11, a12, a22 = h_22_quadrupole(dt, orbit1, mass1)
    b11, b12, b22 = h_22_quadrupole(dt, orbit2, mass2)
    return a11 + b11, a12 + b12, a22 + b22
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

(2,2)-mode strain from the orbits of BH 1 (mass `mass1`) and BH 2 (mass `mass2`).
The masses must sum to one. Returns `(√(π/5) (h11 - h22), -√(π/5) · 2 h12)` built
from the summed two-body components.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    scale = √(T(π) / 5)
    return scale * (h11 - h22), -scale * (T(2) * h12)
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Convert an ODE solution into the (2,2)-mode strain `(h₊, hₓ)` with sample spacing `dt`.

`soln` is turned into a Cartesian orbit with `soln2orbit(soln, model_params)`.

- `mass_ratio == 0`: the one-body (test-particle) strain is returned.
- `0 < mass_ratio ≤ 1`: with `q = mass_ratio`, the orbit is split into two bodies of
  masses `m₁ = q / (1 + q)` and `m₂ = 1 / (1 + q)`, and the two-body strain is returned.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio≤1 "mass_ratio must be <= 1"
    @assert mass_ratio≥0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # A zero mass ratio is the test-particle limit: a single body on `orbit`.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# Simulating the True Model
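`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background:
$$\dot\chi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot\phi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$
with state $u[1] = \chi$ and $u[2] = \phi$, where $p$, $M$, and $e$ are constants.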
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = (χ, ϕ)` with constant
parameters `p`, `M` and `e`. Returns the vector `[χ̇, ϕ̇]`; `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    c = cos(χ)
    numer = (p - 2 - 2 * e * c) * (1 + e * c)^2
    root = sqrt((p - 2)^2 - 4 * e^2)
    ϕ̇ = numer / (M * (p^(3 / 2)) * root)
    χ̇ = numer * sqrt(p - 6 - 2 * e * c) / (M * (p^2) * root)
    return [χ̇, ϕ̇]
end
# Problem setup. Every value below is read inside `loss` on each optimizer iteration;
# non-const globals are untyped in Julia and force dynamic dispatch there, so they are
# declared `const`. The time span is Float64 to match `u0`, the ODE parameters and the
# Float64 network parameters (otherwise `dt_data` and the waveform constants that are
# built from its type would silently be Float32).
const mass_ratio = 0.0 # test particle
const u0 = Float64[π, 0.0] # initial conditions
const datasize = 250
const tspan = (0.0, 6.0e4) # time span for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
const dt_data = tsteps[2] - tsteps[1]
const dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic model with fixed-step RK4 and convert the saved states
# (rows χ and ϕ) into the plus-polarization waveform used as training data.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `waveform` is the training target read by `loss` on every iteration; `const` keeps
# that global access type-stable.
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # The true waveform: a line through the samples, with the samples themselves marked.
    l = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s = scatter!(ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (the orbital variable χ, i.e. `u[1]`) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but in case you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use `truncated_normal` from WeightInitializers.jl.
# The input is passed through `cos`, then through two cos-activated dense layers and a
# final linear 32 => 2 layer. Every weight matrix is drawn from a truncated normal with
# a tiny standard deviation, so the weights start out very close to zero.
const nn = let init = truncated_normal(; std=1e-4)
    Chain(Base.Fix1(broadcast, cos),
        Dense(1 => 32, cos; init_weight=init),
        Dense(32 => 32, cos; init_weight=init),
        Dense(32 => 2; init_weight=init))
end
# Initialize parameters and states from a seeded RNG so the untrained model, and hence
# the whole training run, is reproducible from one execution to the next.
ps, st = Lux.setup(Xoshiro(0), nn)
# Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we
# need Float64.
# Promote the Float32 parameters returned by `Lux.setup` to Float64.
const params = ComponentArray{Float64}(ps)
# Bundle the network with its state so it can be called as `nn_model(x, θ)`; the
# parameters are supplied on every call, hence `nothing` here.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
# Now we define a system of ODEs that describes the motion of a point-like particle
# under Newtonian physics. It uses u[1] = χ and u[2] = ϕ, where p, M, and e are
# constants.
"""
    ODE_model(u, nn_params, t)

Newtonian orbit ODE for the state `u = (χ, ϕ)`. Both rates share the Newtonian factor
`(1 + e cos χ)^2 / (M p^(3/2))`, and each is multiplied by the corresponding component of
`1 .+ nn_model([χ], nn_params)`. `p`, `M` and `e` come from the global
`ode_model_params`; `t` is unused. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params
    # The network state `st` is known to be an empty NamedTuple here, so it can safely be
    # ignored; in general it should be used to carry the network's state.
    correction = 1 .+ nn_model([χ], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# Let us now simulate the neural network model and plot the results. We'll use the
# untrained neural network parameters to simulate the model.
# Solve the same problem with the Newtonian + neural-network right-hand side, using the
# untrained parameters, and compute the resulting waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # True waveform first, then the untrained neural-network prediction on top of it.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        alpha=0.5, marker=:circle, markersize=12, strokewidth=2)
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        alpha=0.5, marker=:circle, markersize=12, strokewidth=2)
    labels = ["Waveform Data", "Waveform Neural Net (Untrained)"]
    axislegend(ax, [[l1, s1], [l2, s2]], labels; position=:lb)
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

"""
    loss(θ)

Solve `prob_nn` with network parameters `θ` and return the mean-squared error between
the true `waveform` and the predicted waveform, together with the predicted waveform
itself (the extra return value is what `callback` receives as `pred_waveform`).
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    strain = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    pred_waveform = first(strain)
    return mseloss(waveform, pred_waveform), pred_waveform
end
# Warm up the loss function
# Warm-up call: runs the full forward pass (ODE solve + waveform + MSE) once with the
# untrained parameters before training starts.
loss(params)
# Now let us define a callback function to store the loss over time
const losses = Float64[]

# Optimizer callback: record the current loss `l`, print a progress line, and return
# `false` so that training is never halted early. `θ` and `pred_waveform` are unused.
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf "Training %10s Iteration: %5d %10s Loss: %.10f\n" "" iteration "" l
    halt = false
    return halt
end
# Training the Neural Network
Training uses the BFGS optimizer, which works well here because the Newtonian model already provides a very good initial guess.
# Zygote supplies the reverse-mode gradients of `loss`.
adtype = Optimization.AutoZygote()
# The objective ignores the problem-parameter slot; it depends only on the network
# parameters being optimized.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and a backtracking line search; `callback` records
# the loss history and the run is capped at 1000 iterations.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.00010314011888112873; -5.667965524480678e-5; 0.00012523557234089822; 6.332905468298122e-5; 0.00010770868539111689; -0.00014197408745530993; -0.00010931961878668517; -6.875454710097983e-5; 7.886718958616257e-5; 5.8389432524563745e-5; 6.676581688225269e-5; -3.3537435228936374e-5; 7.772606477374211e-5; -8.880469067662489e-6; -5.886364306206815e-5; -7.403721247101203e-5; 0.00016276859969366342; 5.473431883729063e-5; 0.00010923559602815658; 8.878230437403545e-5; 9.953052358468994e-5; -0.00013791152741760015; 0.00012123746273573488; -0.00013399406452663243; -2.9717819415964186e-5; 2.360205508011859e-5; 4.566956704366021e-5; -8.890721073839813e-5; -5.626534766633995e-5; -7.113150059012696e-5; -5.906017395318486e-5; -3.8230277255024703e-7;;], bias = [-0.1784806251525879, -0.4354480504989624, 0.6157815456390381, -0.12197494506835938, -0.08926522731781006, -0.13736510276794434, -0.6022071838378906, -0.08306014537811279, -0.3287010192871094, 0.7044640779495239, -0.24852311611175537, 0.28427040576934814, 0.9936147928237915, 0.4528679847717285, -0.2660024166107178, -0.9986660480499268, 0.695454478263855, 0.47927582263946533, -0.7029294967651367, -0.19270110130310059, 0.32516050338745117, 0.703499436378479, 0.3226635456085205, 0.5884358882904053, -0.6432651281356812, 0.1304110288619995, 0.41769862174987793, 0.3918027877807617, 0.7423442602157593, -0.17973220348358154, 0.11803483963012695, 0.9903656244277954]), layer_3 = (weight = [-1.5965892089297995e-5 -1.7210704754688777e-5 -3.875154652632773e-5 -0.00021018848929088563 9.094722918234766e-5 -0.00012927960779052228 4.0027058275882155e-5 -7.702315633650869e-5 -2.406369458185509e-5 0.00011970293417107314 2.214871892647352e-5 -5.766112371929921e-5 -0.00021896345424465835 -2.3988448447198607e-5 9.944893099600449e-5 -3.451275551924482e-5 -0.00011497372906887904 -4.457026443560608e-5 -8.672082003613468e-6 -4.350425297161564e-5 -3.7824607716174796e-5 
-2.787410448945593e-5 -2.8225418645888567e-5 1.5613133655278943e-5 -9.881785081233829e-5 2.955567026674544e-7 7.81068010837771e-5 1.1865755368489772e-5 2.7540030714590102e-5 -0.00022779841674491763 -0.00010880069748964161 -0.00014716651639901102; 4.6376762838917784e-6 2.490586666681338e-5 -9.453898383071646e-5 -3.35876175086014e-5 -0.00019526654796209186 -0.00012524907651823014 9.228252019966021e-5 -0.00011364134843461215 -9.879740537144244e-5 2.442778531985823e-5 -2.8746178941219114e-5 8.697894372744486e-5 -0.00014750752598047256 3.965129144489765e-5 -0.000244331982685253 -7.56617623665079e-7 5.7738325267564505e-5 2.301253698533401e-5 0.0001355375279672444 -9.645480167819187e-5 -1.8934120816993527e-5 0.0001350952370557934 0.0001077391134458594 -1.278354466194287e-5 3.4468390367692336e-5 0.00010418781312182546 -2.4363293050555512e-5 7.682792056584731e-5 2.8915052098454908e-5 -5.62627565159346e-6 -0.00010059787018690258 -0.00017617935372982174; -9.134827996604145e-5 0.00011359809286659583 7.064771489240229e-5 -1.2142215382482391e-5 -7.463231099791301e-7 4.897029066341929e-5 -5.304577643983066e-5 5.805185355711728e-5 0.0001306057965848595 -2.6434814571985044e-5 0.00010185630526393652 -4.793097104993649e-5 7.856030424591154e-5 -0.00013239732652436942 -4.2749605199787766e-5 1.7903760181070538e-6 -7.013235153863207e-5 0.00013206122093833983 -0.00012600822083186358 7.025848753983155e-5 -9.49777529513085e-7 5.4576925322180614e-5 0.00011149981583002955 5.852628601132892e-5 2.4589009626652114e-5 0.00023027135466691107 -1.1367598744982388e-5 2.2629592422163114e-5 -2.9949909730930813e-5 -0.00012454368697945029 0.0002183881151722744 -7.265039312187582e-5; 0.00011764286318793893 4.758335762744537e-6 3.576274320948869e-5 -4.130935849389061e-5 -7.093808380886912e-5 5.0337119319010526e-5 -0.00012014072126476094 6.635741010541096e-5 -0.00011612058733589947 -0.0001041495124809444 -0.0001224525476573035 -0.0001304515899391845 6.0349531850079075e-5 8.363330562133342e-5 
-0.00011422226816648617 -0.00010539696813793853 -0.00012553890701383352 5.1627324864966795e-5 9.549899004923645e-6 1.0726478649303317e-5 -9.197110921377316e-5 -0.00018078026187140495 -6.511643005069345e-5 0.00010008196113631129 -2.6472243916941807e-5 -7.666522833460476e-6 2.087529719574377e-5 9.365034202346578e-6 8.397041528951377e-5 -0.00010766343621071428 6.986475636949763e-5 4.434637230588123e-5; -0.00012410938506945968 -2.9317527150851674e-5 6.459202995756641e-5 -2.6514457204029895e-5 -4.7089444706216455e-5 -3.0058197808102705e-5 -6.372350981109776e-6 -0.00014601321890950203 -8.07913820608519e-5 -8.203064498957247e-5 6.600072811124846e-5 -1.8408967662253417e-5 -0.00015207544493023306 2.8419030968507286e-6 -0.0001424989604856819 5.934580622124486e-5 -0.0001687161420704797 4.3404554162407294e-5 0.00015660013013985008 4.251728023518808e-5 2.6921010430669412e-5 2.9407623514998704e-5 -7.354510307777673e-5 3.33911957568489e-5 -9.29741800064221e-5 0.00012369596515782177 0.00021921112784184515 -8.464437996735796e-6 -7.649881445104256e-5 5.704828436137177e-5 4.424187500262633e-5 5.474955833051354e-5; 8.748385880608112e-5 -0.00021836608357261866 -0.00012509716907516122 -9.044928447110578e-5 2.14337032957701e-5 -3.631714571383782e-5 -0.0001082949383999221 -2.5891687982948497e-5 0.00017383323574904352 -0.00013474533625412732 -3.675613334053196e-5 -0.00020680583838839084 -4.614969657268375e-5 0.00013034531730227172 2.280498119944241e-5 -4.943031308357604e-5 4.4253265514271334e-5 -2.347061672480777e-5 0.00012143251660745591 4.999626980861649e-5 -6.400938582373783e-5 5.463906563818455e-5 -7.996639760676771e-5 1.0080888387165032e-5 -9.264712571166456e-5 6.492906686617061e-5 -0.00019302181317470968 8.573458762839437e-5 9.272969327867031e-5 -1.6120578948175535e-5 2.309455521753989e-5 0.00022674321371596307; -3.062055111513473e-5 4.358521618996747e-5 0.00016767393390182406 -5.580167635343969e-5 0.00013777392450720072 -8.768119005253538e-5 -3.1708092137705535e-5 
-2.716152448556386e-5 -1.2340507964836434e-5 -0.00017624635074753314 -9.311149369750638e-6 0.00016886628873180598 3.5773795389104635e-5 3.1661245884606615e-5 -2.3911529751785565e-6 -4.507323319558054e-5 -9.919286094373092e-5 5.665865319315344e-5 0.00019988982239738107 0.00015662155055906624 -0.00011777110194088891 3.511990871629678e-5 -4.2084040615009144e-5 -8.19120105006732e-5 5.99662380409427e-5 5.0151862524216995e-5 4.517868728726171e-5 -0.00010067669791169465 8.915395301301032e-5 -3.525400461512618e-5 -9.477516141487285e-5 2.9187172913225368e-5; -6.207700062077492e-5 2.098162985930685e-5 9.75031616690103e-6 -6.742614641552791e-5 2.0529723769868724e-5 4.456993701751344e-5 -3.053674390685046e-6 -0.0002691558329388499 -6.604490044992417e-5 0.00011678822920657694 -0.00011160003487020731 4.1671279177535325e-5 0.00010834712884388864 -5.181834058021195e-5 0.0001217526150867343 4.321347296354361e-5 8.108499605441466e-5 -0.00011137482215417549 -4.8795558541314676e-5 4.768407961819321e-5 -9.626385144656524e-5 -3.4029075322905555e-5 -7.624475983902812e-5 0.0001375922147417441 -9.77874078671448e-5 0.0001003517463686876 -0.00014566580648534 8.69214563863352e-5 -9.119344031205401e-5 -0.00026184855960309505 9.164970106212422e-5 0.00018994341371580958; 6.381997809512541e-5 3.9355461922241375e-5 8.567547774873674e-5 0.0001788422086974606 -6.336466321954504e-5 -0.00010402021871414036 0.00012622510257642716 -4.3259737140033394e-5 -4.1585029975976795e-5 -5.4625019402010366e-5 0.00014368927804753184 -4.5967422920512035e-5 -5.796045661554672e-5 1.2718654943455476e-5 0.00020269672677386552 0.00022850334062241018 7.529732829425484e-5 -4.628252281690948e-5 1.4572548025171272e-5 4.923457163386047e-5 5.29722856299486e-5 0.00020021427189931273 0.00012411188799887896 -0.00011731553240679204 0.00019296407117508352 0.00019032573618460447 -3.414899038034491e-5 1.4062982700124849e-5 1.2843363037973177e-5 3.570451372070238e-5 -9.029889952216763e-6 -0.00013647937157656997; 2.14426290767733e-5 
-2.0631041479646228e-5 -5.696651714970358e-5 5.239561141934246e-5 -3.586505772545934e-5 3.166779788443819e-5 8.725442057766486e-6 7.283148443093523e-5 -0.00011397634079912677 5.371383304009214e-5 -0.00019974011229351163 0.00011284500214969739 -2.041066636593314e-6 -5.967679317109287e-5 0.00015624916704837233 -8.895693463273346e-5 -0.00013917787873651832 -0.00012257548223715276 -3.9932241634232923e-5 -0.0001507918641436845 -6.059523229851038e-7 -2.92332042590715e-5 -8.322831126861274e-5 1.7341453713015653e-6 9.365104779135436e-5 1.147315288108075e-5 3.239275247324258e-5 -8.39150816318579e-5 2.036941259575542e-5 7.202096458058804e-5 -5.739893822465092e-5 0.00011802363587776199; -0.000141151191201061 -0.00012245176185388118 -8.724175131646916e-5 1.4309084690466989e-5 -0.00013554476026911288 3.0447938115685247e-5 1.381002766720485e-5 -1.530245936010033e-5 5.394297477323562e-5 -2.332208532607183e-5 3.1710242183180526e-5 -5.448362935567275e-5 1.9930752387153916e-5 -0.00010101585939992219 -0.00015048331988509744 8.508082100888714e-5 -0.00016388564836233854 -8.82172753335908e-5 -1.0398543963674456e-5 9.948829392669722e-5 2.3797205358278006e-5 -6.962047336855903e-5 -0.0001633809442864731 -0.00012735174095723778 7.845585787435994e-5 4.217237346892944e-6 -0.0002808696881402284 -0.0001222691935254261 -6.384491280186921e-5 2.2507751054945402e-5 2.9853166779503226e-5 -8.93361575435847e-5; -3.2818839827086776e-5 3.0386188882403076e-5 -3.2760726753622293e-6 -2.121635952789802e-5 7.396355067612603e-5 -9.672735177446157e-5 0.00011514428479131311 -6.43355815554969e-5 5.1286402594996616e-5 3.8764974306104705e-5 5.318984040059149e-5 1.037375841406174e-5 5.6000262702582404e-5 5.895669528399594e-5 8.838233770802617e-5 3.2121028198162094e-5 0.0001583602133905515 -0.0001016405294649303 0.00011678214650601149 -9.37247314141132e-5 -6.19089814790641e-6 0.0002196763816755265 -7.884579827077687e-5 -0.00019236732623539865 7.928029663162306e-5 5.7493860367685556e-5 -1.52809498104034e-5 
-0.0001471330033382401 -5.0114846089854836e-5 0.00017991643107961863 -7.524296961491928e-5 7.626485057699028e-6; -0.00019379358855076134 -0.00012016314576612785 0.0001940040965564549 0.00015563867054879665 -0.00016226254228968173 -3.220474536647089e-5 6.426498293876648e-5 3.46715169143863e-5 -0.0001379699824610725 7.165124407038093e-5 3.252164242439903e-5 4.928430644213222e-5 -0.00018221988284494728 -1.3664996004081331e-5 2.8837697755079716e-5 -0.00010758221469586715 -0.00010115636541740969 -2.957099241029937e-6 -6.402314465958625e-5 0.00018925867334473878 3.381444912520237e-5 6.890459917485714e-5 -6.56148768030107e-5 3.404402013984509e-5 2.538617309255642e-6 5.0749167712638155e-5 -9.820202103583142e-5 7.858801109250635e-5 -0.00017341310740448534 4.8562167648924515e-5 -3.7471156247192994e-5 0.00015078710566740483; 0.00011244069901295006 3.785757871810347e-5 -0.00013611438043881208 -2.4778955776127987e-5 7.140261004678905e-5 -8.688013622304425e-5 -6.95636190357618e-5 8.155697287293151e-5 0.00011684214405249804 -8.921757398638874e-5 -0.00011169495701324195 0.0002248873934149742 -0.00010340818698750809 -0.00012511282693594694 3.6756551708094776e-5 2.2812893803347833e-5 -1.996086211875081e-6 0.00020240867161192 9.297489305026829e-5 -1.6309460988850333e-5 -7.579658267786726e-5 4.192117921775207e-5 -4.2652634874684736e-5 9.779665560927242e-6 -2.149494866898749e-5 0.00011866330896737054 3.4366919862804934e-5 -1.375625652144663e-5 1.1910218745470047e-5 3.526044383761473e-5 -0.00014080031542107463 -2.3328917450271547e-5; 7.816452125553042e-5 0.00015193353465292603 -0.00022339883435051888 -0.00013748573837801814 1.5262161468854174e-5 2.5801984520512633e-5 6.254988693399355e-5 -4.5300967030925676e-5 2.6145369702135213e-5 -0.00012001836876152083 -0.00012769264867529273 -6.629679774050601e-6 -0.0001540818775538355 5.5824907576607075e-6 -0.00011086570884799585 -2.3610153220943175e-5 -6.75585470162332e-5 -5.9244423027848825e-5 6.75334595143795e-5 -1.5082744539540727e-5 
3.533241033437662e-5 3.123487113043666e-5 -3.93752234231215e-5 0.0001200137849082239 7.200749678304419e-5 -7.873112917877734e-5 2.6774003345053643e-5 1.2791468179784715e-5 -3.778906466322951e-5 1.1799311323557049e-5 7.313647074624896e-5 -8.774539310252294e-5; -1.6160085579031147e-5 9.189550837618299e-6 8.662323671160266e-5 4.199157046969049e-6 5.3556199418380857e-5 -0.00011431414895923808 7.372589607257396e-5 -2.6671827072277665e-5 -3.973170805693371e-6 -5.859162411070429e-5 -3.302143522887491e-5 6.197072070790455e-5 0.0001029945706250146 -1.4137211110210046e-5 -4.887353134108707e-5 2.64192226495652e-6 7.085027755238116e-5 -1.9212051483918913e-5 2.9954542242194293e-6 7.983821706147864e-5 -0.00012519668962340802 8.001715468708426e-5 -2.2866539438837208e-5 -0.00011077660019509494 0.0001331132953055203 4.304469985072501e-5 -0.0001450149720767513 5.949800834059715e-5 0.00014131945499684662 -4.484240889723878e-6 -0.00010133361502084881 3.443357491050847e-5; 0.0002487756428308785 -8.92321186256595e-5 -5.570185021497309e-5 0.00014102684508543462 -3.936992652597837e-5 4.6886761992936954e-5 9.31056056288071e-5 -6.118643068475649e-5 -8.936517406255007e-5 -3.673292667372152e-5 2.3888864234322682e-5 -4.0771181375021115e-5 -0.00022242192062549293 0.00016716068785171956 1.2486674677347764e-5 3.880575968651101e-5 0.0001092824459192343 -9.543128726363648e-6 -1.5052160961204208e-5 1.0552595597346226e-7 2.945265077869408e-5 0.00010975986515404657 -0.00014405633555725217 2.6152831196668558e-5 4.348623406258412e-5 -1.8970551536767744e-5 -9.984248754335567e-5 -0.0001673195365583524 1.9386536223464645e-5 4.9159785703523085e-5 0.00016875391884241253 2.3785094072081847e-6; -1.960230656550266e-5 -0.00018978965817950666 -9.619152478990145e-6 0.00014125325833447278 -0.00012960421736352146 -9.117174158745911e-6 5.296656127029564e-6 -5.367121775634587e-5 0.00017816221225075424 -4.650739356293343e-5 3.754750287043862e-5 8.360924402950332e-5 0.00029420366627164185 -0.00013281618885230273 
-0.00018910728977061808 0.0001978215586859733 6.969922105781734e-5 0.00010056075552711263 6.889748328831047e-5 -1.4721874322276562e-5 7.170574826886877e-5 -5.597593917627819e-5 5.0444221415091306e-5 -0.00011626030754996464 -4.235205051372759e-5 8.659496234031394e-5 -2.009684067161288e-5 -4.464910671231337e-5 -2.494866566848941e-5 5.1334951422177255e-5 -5.8871428336715326e-5 -3.454313991824165e-5; 7.148530130507424e-5 0.00023423398670274764 -7.890514825703576e-5 -6.521742034237832e-5 -0.00023374901502393186 0.0001676215324550867 -7.983858813531697e-5 -4.844716386287473e-5 -5.463084380608052e-5 6.16178076597862e-5 1.8167296730098315e-5 0.00030005452572368085 -3.0760485969949514e-5 2.5025472496054135e-5 0.00013758827117271721 -0.0001586420403327793 -4.909022027277388e-5 6.071538882679306e-5 -0.00011646570055745542 0.00015134530258364975 8.48913041409105e-5 1.0760394616227131e-5 0.00018101873865816742 8.938741666497663e-5 1.926824370457325e-5 1.859581607277505e-5 -2.1985799321555533e-5 7.301032019313425e-5 -0.00010462694626767188 -0.0001650658086873591 -2.2235621145227924e-5 -1.426146445737686e-5; 3.850275243166834e-5 -0.0001259626296814531 -5.8795438235392794e-5 1.4558035218215082e-5 -5.8737208746606484e-5 -3.103263952652924e-5 -6.3235369452741e-5 0.00010674171790014952 5.112551662023179e-5 -4.715444447356276e-5 -3.329246055727708e-6 -7.079471834003925e-5 2.6740372049971484e-5 2.4078462956822477e-6 -1.156043254013639e-5 -5.880358457943657e-6 -0.00012249623250681907 -7.91404727351619e-6 -0.00013310684880707413 -1.0869949619518593e-5 -0.0002106893080053851 5.117377804708667e-5 5.0224713049829006e-5 -6.091585601097904e-5 2.470862091286108e-5 -1.125753260566853e-5 1.7322361600236036e-5 3.889134677592665e-5 4.3135460146004334e-5 0.00010927008406724781 7.569418812636286e-5 -3.362006100360304e-5; 2.9530627216445282e-5 -8.707673259777948e-5 -0.0001314832188654691 -6.12296280451119e-5 -1.5164911246756674e-6 -0.00010785646009026095 1.967878415598534e-5 7.195354555733502e-5 
4.239197733113542e-5 4.860897024627775e-5 -6.770586333004758e-5 -4.5619235606864095e-5 -9.56523435888812e-5 -0.00011888708104379475 -6.106846558395773e-5 -0.00018869427731260657 -2.6916082788375206e-5 9.0532805188559e-5 -1.5507170246564783e-5 -1.4515834664052818e-5 1.2838664588343818e-5 -7.146996358642355e-5 -7.542833918705583e-5 7.103578536771238e-5 6.310829689937236e-7 2.686386324057821e-6 5.2625393436755985e-5 1.6737578334868886e-5 -2.7376147045288235e-5 -0.00013561481318902224 7.198296952992678e-5 -0.00013826932990923524; -1.5858930055401288e-5 4.397627344587818e-5 1.0710098649724387e-5 -9.396316454512998e-5 -0.00013255061639938504 7.680121052544564e-5 -5.715243241866119e-5 -7.213321077870205e-5 -7.506121619371697e-5 9.501705062575638e-5 -0.00012963710469193757 6.320537795545533e-5 -4.1696835978655145e-5 4.58081194665283e-6 -5.662546755047515e-5 0.000126066806842573 2.7632549972622655e-5 -5.840681478730403e-5 -7.758440187899396e-5 4.737409471999854e-5 -7.999392255442217e-5 5.788618000224233e-5 -8.13984515843913e-5 7.716756226727739e-5 1.4089254364080261e-5 3.484235276118852e-5 -9.913359099300578e-5 2.1371377442847006e-5 -0.0001568483712617308 1.9419239833950996e-5 -5.213919848756632e-6 -4.019343759864569e-5; 8.129367779474705e-5 -3.890289008268155e-5 -0.00011510320473462343 -9.23492043511942e-5 -0.00020380105706863105 -0.0001067452976712957 0.00022104733216110617 4.8936537496047094e-5 7.913571607787162e-5 0.00014844353427179158 5.064608785687597e-7 -4.144840931985527e-5 -2.6540286853560247e-5 7.515913603128865e-5 6.527952791657299e-5 -0.0001690170611254871 -0.00012381702254060656 -5.751452408730984e-5 -9.300779493059963e-5 -3.279584780102596e-5 -1.2239150237292051e-5 -2.120339013345074e-5 0.0002173062093788758 -5.562488877330907e-5 8.709626854397357e-5 -1.9055929442401975e-5 -2.954255251097493e-5 -9.037499694386497e-5 -0.00010991096496582031 0.00011255755816819146 -6.162698264233768e-5 -7.597458170494065e-5; -1.2733069524983875e-5 -0.0002427449799142778 
4.511968199949479e-6 -6.804762233514339e-5 6.251981540117413e-5 8.824207907309756e-5 -7.349470251938328e-6 -9.250314178643748e-5 -4.816245200345293e-5 1.3041470083408058e-5 -5.699506073142402e-5 -0.00016340507136192173 1.190621424029814e-5 5.773192606284283e-5 -3.318192102597095e-5 4.793423795490526e-5 0.0002276865125168115 9.484928887104616e-5 -3.205991743016057e-5 -5.336406684364192e-5 0.00011895406350959092 7.674735388718545e-5 -0.0002234304993180558 -0.0001499107456766069 2.4902687982830685e-6 -5.400328063842608e-6 -7.601254765177146e-5 -1.1728393474186305e-6 -7.101287337718531e-5 0.00022008222003933042 -0.00011229336814722046 -7.52007108530961e-5; -5.260734906187281e-5 0.00017633644165471196 -7.604485290357843e-5 -0.0002178693248424679 -0.0001384528004564345 2.423176556476392e-5 -0.00020543513528537005 0.0001333223917754367 -0.00014098417886998504 -0.00023364265507552773 1.7533386198920198e-5 0.00018820996046997607 1.5345000065281056e-5 6.032456803950481e-5 8.793143206276e-5 7.781013118801638e-5 -0.00010333899990655482 0.0001397758605889976 7.463766087312251e-5 -0.00012111260002711788 -0.00014345362433232367 -2.067415880446788e-5 0.0001235340314451605 -4.888710009254282e-6 -8.42882291181013e-5 -5.222600157139823e-5 -1.965200499398634e-5 2.9646122129634023e-5 -6.404831947293133e-5 0.00012992843403480947 -2.724828664213419e-5 0.0001565282727824524; 9.083672921406105e-5 6.404357554856688e-5 -9.33260889723897e-5 0.00014834145258646458 0.00011703099153237417 0.00018198562611360103 4.116247873753309e-5 0.0001299926807405427 -7.877343887230381e-5 0.00017835652397479862 -6.288415897870436e-5 -5.468190283863805e-5 5.689861427526921e-5 0.00026451298617757857 0.00010667178139556199 0.00016459387552458793 3.9367067074636e-5 -0.0001521574449725449 0.0001104553448385559 2.533184851927217e-5 -4.669573900173418e-5 7.889077824074775e-5 0.00013684688019566238 0.00013538017810788006 1.182132382382406e-5 -7.111672312021255e-5 -4.3277650547679514e-5 -0.00018624134827405214 
0.00011458972585387528 0.00012191406858619303 -1.6577161659370176e-5 0.0001392103440593928; 7.94300576671958e-5 0.00022015500871930271 1.5024253116280306e-5 -0.00017957488307729363 -6.695886258967221e-5 6.688849680358544e-5 3.5427732655080035e-5 8.000164234545082e-5 0.00010245997691527009 -5.849601802765392e-5 -0.00017049568123184144 -5.026153758080909e-6 -0.00019929467816837132 -0.0002001619868678972 -1.4382944755197968e-5 3.6192788684275e-5 -9.964883065549657e-5 -3.076726716244593e-5 8.05096497060731e-5 -9.963985212380067e-5 -8.506856829626486e-5 0.0001263400772586465 0.00010481034405529499 0.0001248565677087754 -2.0657567802118137e-5 -1.6489117115270346e-5 1.1680722309392877e-5 0.00012664291716646403 -0.0001259988930542022 -0.00010789565567392856 -2.9684448236366734e-5 4.999984230380505e-5; 6.84017431922257e-5 -2.9325652576517314e-5 0.00010669002949725837 8.633989637019113e-5 -1.47505697896122e-5 9.714422776596621e-5 1.1148108569614124e-5 -3.516899232636206e-5 -1.8344659110880457e-5 -3.0060557037359104e-5 -0.00012740657257381827 5.274187060422264e-5 -5.7897952501662076e-5 -0.00013586404384113848 -1.744536348269321e-5 5.184439942240715e-5 0.00016089549171738327 -4.546032505459152e-5 5.304137084749527e-5 -1.5608522517140955e-5 -0.0001448085531592369 0.0001068389683496207 -0.00016285621677525342 -8.966773020802066e-5 -3.0275747121777385e-5 0.00011890051973750815 -5.1578565035015345e-5 0.00013387083890847862 5.9420963225420564e-6 -4.2698469769675285e-5 7.369244849542156e-5 0.00013470127305481583; -0.0001349480589851737 7.949888822622597e-5 3.0434901418630034e-5 -7.213976641651243e-5 9.572458111506421e-6 0.00012619154585991055 5.8709498262032866e-5 -6.554937135661021e-5 -5.715501538361423e-5 0.00017820378707256168 0.0001373947598040104 4.905302557745017e-5 5.944371878285892e-6 -9.390123159391806e-5 -1.631353188713547e-5 6.788143946323544e-5 -2.61089917330537e-5 -1.3561012792706606e-6 -3.1739903079142096e-6 0.00012721044186037034 -3.3812520996434614e-5 
0.00014300264592748135 -5.900414544157684e-5 -4.057234400534071e-5 0.00017455306078772992 3.107005250058137e-5 -0.00013102032244205475 2.057851270365063e-5 -1.9512346625560895e-5 -3.397012551431544e-5 -0.00018503167666494846 2.6964140488416888e-5; -1.2432045878085773e-5 -0.00021558284061029553 -0.00014066889707464725 1.739001163514331e-5 4.9559959734324366e-5 -4.477810944081284e-5 -2.2656084183836356e-5 0.00010145016858587041 -3.14856406475883e-5 -9.79827091214247e-5 4.801351315109059e-5 0.00019290381169412285 6.200386997079477e-5 4.578467269311659e-5 -0.00010381622996646911 -4.0845658077159896e-5 -2.9486638595699333e-5 -0.00013232532364781946 -4.359917147667147e-5 0.00016916339518502355 -0.000160634343046695 -0.00016074064478743821 3.612093496485613e-5 -2.9035651095909998e-5 -0.00013278266123961657 -5.355351822800003e-5 5.131739817443304e-5 -4.162752884440124e-5 -3.792058487306349e-5 -5.277838499750942e-5 7.009942783042789e-5 -0.00013293804659042507; 6.489906081696972e-5 -0.00021913062664680183 5.7981687859864905e-5 6.256288179429248e-5 -6.284705159487203e-5 0.00010556220513535663 -0.00012198572949273512 -6.0300659242784604e-5 -0.00022435445862356573 9.010373105411418e-6 0.00019542443624231964 0.00014523521531373262 0.0002175371046178043 9.032671368913725e-5 2.221999784524087e-5 0.0001249287452083081 9.412400686414912e-5 -1.9812303435173817e-5 6.028301868354902e-5 -8.606081973994151e-5 -2.8864629712188616e-5 0.00012067774514434859 -5.780879291705787e-6 -1.6801219317130744e-5 0.00028505668160505593 -6.47866545477882e-5 -2.762393341981806e-5 4.2717296310001984e-5 -3.872471643262543e-5 0.00014610582729801536 -3.7157180486246943e-5 0.00019049087131861597; 1.773810618033167e-5 -0.00011536067904671654 0.0001238585391547531 0.00010979478247463703 6.230253347894177e-5 8.897904626792297e-5 -0.00015178340254351497 6.839123670943081e-5 -0.00013320280413608998 -7.051991269690916e-5 3.8099282392067835e-5 -5.789395436295308e-5 0.00016538931231480092 -0.00011224256013520062 
6.19039055891335e-5 0.00014622723392676562 0.00017222769383806735 -0.0001363351766485721 5.6025393860181794e-5 -0.00011948203609790653 0.00010875794396270066 2.5430395908188075e-5 5.499886174220592e-5 -4.95445347041823e-5 -5.036100264987908e-5 0.00018160544277634472 5.5177206377265975e-5 3.33007556037046e-5 -6.294114882621216e-6 0.0001879667688626796 -0.00017742003547027707 3.039529656234663e-5], bias = [-0.15440703928470612, 0.12235160917043686, 0.09640813618898392, 0.07331212610006332, -0.12968003749847412, -0.08275624364614487, -0.09084632992744446, -0.01504391711205244, 0.16256718337535858, -0.12451540678739548, -0.02188093587756157, 0.12437385320663452, -0.043862078338861465, -0.16444800794124603, -0.07131785899400711, -0.029454998672008514, -0.06485915184020996, 0.0014886056305840611, -0.021028580144047737, -0.034277062863111496, 0.05301447585225105, -0.07453048229217529, 0.12452807277441025, -0.03442782163619995, -0.0836026594042778, -0.1605084091424942, -0.15131723880767822, -0.05417707562446594, -0.06210154667496681, -0.012042914517223835, 0.10248588025569916, -0.13440975546836853]), layer_4 = (weight = [-4.230277772876434e-5 -6.918260623933747e-5 1.3117176422383636e-5 -5.567428161157295e-5 0.00010370484960731119 -9.793874778551981e-5 -6.373551877913997e-5 -0.00014454089978244156 2.5264243959099986e-5 1.7955206203623675e-5 9.576162847224623e-5 -1.515610347269103e-5 1.923370746226283e-6 -7.955553155625239e-5 0.00019175214401911944 -0.0002239962777821347 0.00016495282761752605 2.1663505322067067e-5 9.58572854869999e-5 2.2072455976740457e-5 -4.3864281906280667e-5 0.00011734620784409344 -1.2333764061622787e-5 -0.00017161565483547747 -4.227673707646318e-5 7.363154873019084e-5 -3.761819243663922e-5 5.103584044263698e-5 6.207667320268229e-5 -3.3743031963240355e-5 3.130464392597787e-5 -6.882468005642295e-5; -9.766876246430911e-6 -9.21195896808058e-5 -5.189688818063587e-5 -3.4824286558432505e-5 -8.882705151336268e-5 0.00013353988470043987 -8.032831829041243e-5 
-6.433802627725527e-5 7.6721677032765e-5 7.503211236326024e-5 -8.202541357604787e-5 -9.680167568149045e-5 0.0002530438359826803 -1.34569108922733e-5 2.5403269319212995e-5 -6.149481487227604e-5 2.04203934117686e-5 -0.0001200764236273244 7.508493581553921e-5 0.00010756983829196543 1.1220228088859585e-6 -2.2631695173913613e-5 -2.0545807274174877e-5 6.605764792766422e-5 1.890482599264942e-5 -9.256904013454914e-5 -0.00010431870759930462 -0.00011025780986528844 0.00021316396305337548 3.539343742886558e-5 8.964975859271362e-5 8.890653407434002e-5], bias = [-0.04932510852813721, 0.028084194287657738]))Visualizing the Results
Let us now plot the loss over the course of training.
begin
    # Plot the loss history collected in `losses` against the iteration number,
    # drawn both as a line and as individual markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    iterations = eachindex(losses)
    lines!(ax, iterations, losses; linewidth=4, alpha=0.75)
    scatter!(ax, iterations, losses; marker=:circle, markersize=12, strokewidth=2)
    fig
end
# Finally let us visualize the results
# Re-solve the neural ODE with the optimized parameters `res.u` using fixed-step RK4
# (`adaptive=false`, step `dt`), saving at `tsteps`, then convert the trajectory into
# a waveform.
θ_trained = res.u
prob_nn = ODEProblem(ODE_model, u0, tspan, θ_trained)
soln_nn = solve(prob_nn, RK4(); u0, p=θ_trained, saveat=tsteps, dt, adaptive=false) |>
          Array
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    # Overlay the data, the untrained network and the trained network; every series
    # is drawn as a line plus markers using the same shared styles.
    line_style = (; linewidth=2, alpha=0.75)
    marker_style = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; line_style...)
    s1 = scatter!(ax, tsteps, waveform; marker_style...)
    l2 = lines!(ax, tsteps, waveform_nn; line_style...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_style...)
    l3 = lines!(ax, tsteps, waveform_nn_trained; line_style...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_style...)

    legend_entries = [[l1, s1], [l2, s2], [l3, s3]]
    legend_labels = ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"]
    axislegend(ax, legend_entries, legend_labels; position=:lb)
    fig
end
# Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report GPU toolchain details only when MLDataDevices is loaded and the matching
# backend package (CUDA / AMDGPU) is both loaded and functional. The device types are
# qualified with `MLDataDevices.` so this also works when the package was brought in
# with `import` (which does not put `CUDADevice`/`AMDGPUDevice` in scope).
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(MLDataDevices.CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(MLDataDevices.AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate

This page was generated using Literate.jl.