Using older version of Lux.jl
This tutorial cannot be run on the latest Lux.jl release due to downstream packages not being updated yet.
Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model, however, from a scientific machine learning perspective, isn't super relevant.
We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector.
# Split a single relative-orbit path into the two individual body trajectories
# about the common center of mass, weighted by the mass ratio.
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    # Body 1 sits at (m₂/M)·path, body 2 opposite at (-m₁/M)·path.
    return (m₂ / total_mass) .* path, (-m₁ / total_mass) .* path
end
one2two (generic function with 1 method)
Next we define a function to perform the change of variables:
# Convert an ODE solution (rows = state variables, columns = time steps) into
# Cartesian orbit coordinates `[x; y]` via the change of variables
# r = p / (1 + e·cos χ), x = r·cos ϕ, y = r·sin ϕ.
#
# With 2 rows (χ, ϕ) the orbital elements (p, M, e) must be supplied through
# `model_params`; with 4 rows, p and e evolve and are read from the solution.
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]
        # Bug fix: the message previously referenced `size(soln,2)`, but the
        # branch condition is on `size(soln,1)`.
        @assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
# First time derivative of a uniformly sampled series: second-order one-sided
# stencils at both endpoints, central differences in the interior.
function d_dt(v::AbstractVector, dt)
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [left; interior; right] / dt
end
d_dt (generic function with 1 method)
This function computes the second time derivative, again using second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
# Second time derivative of a uniformly sampled series: second-order one-sided
# four-point stencils at the endpoints, central three-point in the interior.
function d2_dt2(v::AbstractVector, dt)
    left = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    right = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [left; interior; right] / (dt^2)
end
d2_dt2 (generic function with 1 method)
Now we define a function to compute the trace-free moment tensor from the orbit
# One component of the trace-free mass quadrupole tensor for the given orbit.
# `component` selects (1,1), (2,2), or (by fallthrough) the off-diagonal entry.
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    xx = x .^ 2
    yy = y .^ 2
    xy = x .* y
    tr = xx .+ yy
    i, j = component
    # Diagonal entries have a third of the trace removed; the off-diagonal
    # entry is already trace-free.
    entry = if i == 1 && j == 1
        xx .- tr ./ 3
    elseif i == 2 && j == 2
        yy .- tr ./ 3
    else
        xy
    end
    return mass .* entry
end
# Strain contribution of one quadrupole-tensor component: twice its second
# time derivative (quadrupole formula).
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
# All three independent quadrupole strain components for a single orbit.
function h_22_quadrupole(dt, orbit, mass=1)
    h11, h22, h12 = map(
        c -> h_22_quadrupole_components(dt, orbit, c, mass), ((1, 1), (2, 2), (1, 2)))
    return h11, h12, h22
end
# Plus and cross polarizations of the (2,2) strain mode for a single body.
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    scale = √(T(π) / 5)
    # h₊ ∝ h11 - h22, hₓ ∝ 2·h12 (with an overall sign flip on the cross part).
    return scale * (h11 - h22), -scale * (T(2) * h12)
end
# Quadrupole strain components of a binary: sum of each body's contribution.
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_a, h12_a, h22_a = h_22_quadrupole(dt, orbit1, mass1)
    h11_b, h12_b, h22_b = h_22_quadrupole(dt, orbit2, mass2)
    return h11_a + h11_b, h12_a + h12_b, h22_a + h22_b
end
# Plus and cross polarizations of the (2,2) strain mode for a binary whose
# masses are normalized to sum to one.
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    scale = √(T(π) / 5)
    return scale * (h11 - h22), -scale * (T(2) * h12)
end
# Compute the (h₊, hₓ) waveform from an ODE solution. A zero mass ratio is
# treated as the test-particle (one-body) limit; otherwise the orbit is split
# into two bodies with masses normalized so m₁ + m₂ = 1.
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio≤1 "mass_ratio must be <= 1"
    @assert mass_ratio≥0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
compute_waveform (generic function with 2 methods)
Simulating the True Model
RelativisticOrbitModel
defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background:
where,
# ODE right-hand side for a test particle on a Schwarzschild background,
# parameterized by semi-latus rectum p, central mass M, and eccentricity e.
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    cosχ = cos(χ)  # hoist the shared trig evaluation
    numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    return [numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * denom),
        numer / (M * (p^(3 / 2)) * denom)]
end
mass_ratio = 0.0 # test particle (one-body) limit
u0 = Float64[π, 0.0] # initial conditions [χ₀, ϕ₀]
datasize = 250 # number of saved time points
tspan = (0.0f0, 6.0f4) # timespan for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1] # spacing of the saved data (used by the FD stencils)
dt = 100.0 # fixed integrator step for the non-adaptive RK4 solve
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Solve the exact relativistic model with fixed-step RK4 and extract the
# plus-polarization waveform from the resulting orbit.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
# Plot the reference ("true") waveform: a line plus scatter markers at the
# saved time steps.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function ODE_model
that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives.
It is typically never recommended to use globals, but in case you do use them, make sure to mark them as const
.
We will deviate from the standard Neural Network initialization and use WeightInitializers.jl
,
# Small MLP with cos activations. Weights are initialized close to zero
# (std=1e-4) so the untrained network is only a tiny perturbation of the
# baseline Newtonian model.
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-7.6212375f-5; -6.828004f-5; 3.4645545f-5; 1.9029168f-5; -6.121949f-5; -6.8600064f-5; 7.876623f-6; -1.581312f-5; -0.000101037505; -5.185804f-5; -1.4994273f-5; 9.659224f-5; 0.0001628358; 6.454488f-5; 2.8993047f-5; -3.8898743f-5; -3.768681f-5; -6.453134f-5; -3.889051f-5; 0.00016972773; -4.7918198f-5; -0.00017601385; -6.921436f-5; 9.0903355f-5; 3.5885025f-6; -1.9706784f-5; 0.0001064235; 2.9860666f-5; 3.480168f-5; -4.4413697f-5; -0.0001593494; 4.4804172f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-2.9125753f-5 3.567069f-5 2.798114f-5 -4.7662405f-5 -5.764549f-5 -0.00011001285 1.0210566f-5 -0.00017521101 0.00019371795 -0.00023500848 5.2675525f-5 -0.00025270172 -1.981023f-6 -9.488267f-5 0.00017032781 -4.61919f-6 0.00023472053 -8.50571f-5 8.272881f-5 -0.00010589188 -0.00010010656 2.6129648f-5 -4.5024084f-5 6.530895f-5 -1.0581961f-5 -0.00019777074 -0.00011138917 -7.609641f-5 -6.981665f-5 -6.54895f-5 -0.00013258001 9.568346f-5; 8.200047f-5 6.6048706f-5 -0.0001481901 5.4275595f-5 5.7582365f-6 0.0001034996 4.9395323f-5 6.934915f-5 -0.00015803114 0.00013334687 7.9401885f-5 0.00020656642 -9.582842f-5 0.00020369531 7.637229f-6 8.156939f-5 6.34335f-5 -1.825247f-5 7.152902f-5 -9.965721f-5 -5.9072285f-5 7.816615f-5 -0.0001352522 0.00013552612 9.83203f-5 -4.686906f-5 5.205396f-5 1.2746583f-5 -0.00013009824 0.00013810657 0.00011642049 0.00019346754; 0.0001658512 7.6874916f-5 7.5848344f-5 3.203486f-5 -0.00020807736 -7.39607f-5 -0.000113695394 1.7584454f-5 5.6710483f-5 -0.000110271154 9.128635f-5 3.132791f-5 -6.4871485f-5 8.5387015f-5 0.00012529909 6.3194355f-5 -5.7577465f-5 0.000102258964 -0.00013979159 -0.000116966476 -7.450439f-5 1.2880439f-5 -4.7402733f-5 -0.00021483668 0.00012607299 -4.6009256f-5 2.2985305f-5 0.00011769362 7.429404f-5 -8.912102f-6 0.00012118743 
1.5294734f-5; -0.00022166315 0.00019751045 -5.9034646f-5 -4.8447513f-5 -2.9117791f-5 -7.4056465f-5 0.00011228949 9.4382915f-5 -2.5821462f-5 4.1751864f-5 3.5742207f-5 -2.2179036f-5 -0.00012899567 -0.00020848302 4.4149605f-5 -6.080305f-5 0.000113438255 2.2923152f-5 -6.4959568f-6 1.3273528f-6 2.388089f-5 0.00013466296 -4.9967934f-5 -2.957405f-5 -4.2109536f-5 -0.00016464437 -0.00011346999 6.392958f-5 -6.616563f-5 8.523455f-5 -6.909384f-5 -4.7453774f-5; 6.0090257f-5 -2.6073007f-5 8.7179935f-5 -2.2071617f-5 -4.8804814f-5 -0.00014718303 -0.00010785581 -4.528939f-5 8.0438434f-7 5.6190616f-5 3.2122913f-5 6.165884f-5 1.3342547f-5 3.3086857f-5 -4.4416352f-5 -8.2099665f-5 -2.3152656f-5 0.00011316468 7.172059f-5 6.2786356f-5 4.5949324f-5 -0.00014296595 -0.00013794292 -9.612009f-5 -8.539212f-5 0.00018920859 0.00014142944 0.00026129355 -0.00012866023 -0.00014262056 -0.000110646375 -6.206323f-5; 1.0066262f-5 -3.953342f-5 1.8394563f-5 -0.000291189 -0.00011898481 -1.8667255f-5 3.621303f-5 -4.1135158f-5 5.532014f-5 -3.427537f-6 7.92196f-5 -0.00010108239 1.5418598f-5 -8.682844f-5 -1.3022334f-5 5.0930503f-5 -0.00022322504 0.000105777915 -0.00013875314 2.6582493f-5 3.054955f-5 -9.278947f-5 -0.00014793713 -4.0040595f-5 0.00023348269 9.737585f-7 -0.00014717113 -0.00017316452 1.4606237f-5 1.2407821f-5 7.2664494f-5 0.00028479882; 2.9944998f-5 0.00013000712 -2.7386093f-5 7.939813f-5 0.00016632903 6.019708f-5 -0.0003407532 2.2290727f-5 -8.26444f-5 -8.8584165f-5 9.315455f-6 8.146916f-6 7.5862376f-6 5.543574f-5 0.00015039626 6.1695486f-5 -6.276561f-5 3.585618f-5 -2.0422927f-5 1.1454865f-5 -7.076048f-5 0.00014152362 -5.734965f-5 -7.999866f-5 3.9625716f-5 -2.573205f-5 -9.158882f-5 9.780686f-5 4.3795637f-5 7.960357f-5 -0.00020796164 2.0895406f-5; 5.598297f-7 7.6723285f-5 1.4974934f-5 0.00014928909 1.8446683f-5 -1.4770298f-5 -2.1648325f-6 -0.00013367669 1.3181045f-5 1.9742996f-5 -0.00011696348 4.6901155f-6 -0.00012138633 -6.807603f-5 0.0001942279 -3.264488f-5 -0.00015237676 1.4479009f-5 
3.9723353f-5 0.00019861515 -2.3650591f-6 1.8770477f-5 -3.1722007f-5 3.9401326f-5 7.5329095f-5 -3.8095513f-6 0.0001310157 0.000153429 8.939615f-5 1.389927f-5 4.7457866f-6 -6.821081f-5; -1.5144509f-5 -7.0669696f-5 -4.15517f-5 2.3325158f-5 4.0773335f-5 8.925939f-5 -0.00013984762 -0.00013316998 0.00010203136 4.5510966f-5 1.4856112f-5 1.0743353f-5 6.7575485f-5 -0.00010482243 -5.0789467f-6 -0.00016010048 -0.00014118834 -8.9530964f-5 7.639962f-5 0.00013971783 4.3932964f-6 0.00013410911 -9.007911f-5 0.00033075147 -0.00017155166 -0.00011733521 0.00014584677 -3.0750412f-5 -5.750395f-7 -0.0001042372 0.00012930458 -0.00020939486; -3.5898407f-5 -0.00016331606 0.00010011738 -6.8505185f-5 -4.1943527f-5 8.876543f-5 -0.00010751347 -0.00012764103 3.8836515f-5 -8.653255f-5 -4.6427238f-5 -0.0001085046 -3.0680992f-5 2.9756153f-5 -0.00029465498 4.4805744f-5 -4.413957f-5 -7.726167f-5 -5.9095732f-6 -2.8364075f-5 -3.120865f-6 9.828517f-6 7.3201736f-5 -2.6330412f-5 -3.189111f-5 -7.4346084f-5 8.24358f-5 -0.00010816293 -0.00011648428 5.755484f-5 -6.7414294f-7 0.00010446791; -0.0001171122 3.172959f-6 -8.721566f-5 -5.797071f-5 -0.00016205876 -0.00011344688 -6.914407f-5 2.3119648f-5 -3.3227232f-5 -8.52072f-5 -5.8017995f-5 0.00010394976 4.325106f-5 -0.00010775347 -4.2851654f-5 -0.00017612056 0.00020678234 -8.837138f-5 -0.000103192695 -6.4222346f-5 -8.537329f-5 -0.00013041195 0.0001159817 5.05721f-5 -9.510397f-5 3.5480025f-5 -0.000115547075 -0.00016506454 3.6180998f-5 0.00010477451 -0.00010511107 0.00012490455; -2.44358f-5 -6.377469f-5 -3.679295f-5 1.3863759f-5 5.622315f-5 -1.4299998f-5 -2.996321f-6 -0.00014427233 9.961688f-5 -2.7210465f-6 9.952784f-5 2.0772784f-5 0.0001815766 -0.00010266579 -0.00014838326 -6.374809f-5 -0.00017822959 8.358795f-5 -5.4465858f-5 -0.00021779082 -1.5651098f-5 -3.3638094f-5 -7.4766785f-7 6.0352613f-5 -6.0760918f-5 2.4883626f-5 -2.9060164f-5 3.6877344f-5 -5.3767642f-5 -5.4370852f-5 0.00012313115 -0.000109267836; 6.618383f-5 4.8632442f-5 6.933369f-5 3.7330712f-5 
-3.8223236f-5 -0.00022490633 3.0920155f-6 9.432528f-5 -0.00012147321 0.00010349434 4.2920394f-5 1.5970856f-5 0.0001499698 3.6488007f-5 4.560448f-5 0.00016086116 -0.00012663515 7.900059f-5 -0.00014690855 -0.00012629192 0.00011125174 -0.00023642786 -4.210797f-5 8.4481995f-5 5.2535f-5 -0.000114589035 5.4153694f-5 -0.0001535481 -1.3776468f-5 0.00013601469 2.2539827f-5 0.00014900966; -3.167169f-5 0.00020821282 -5.45953f-5 -0.00022239787 -4.5098044f-5 7.490127f-5 2.8284663f-5 -5.3857653f-5 -8.718302f-5 6.4951186f-5 -0.000104586274 -3.5932712f-6 7.196917f-5 -6.0122926f-5 -7.257043f-5 -0.00010184014 0.0001251441 0.00012656506 -5.6977908f-5 0.00014042002 -2.3508467f-6 -1.5926915f-5 -3.3582655f-5 2.6494123f-5 2.3003515f-5 -4.1816522f-5 0.00015351892 -5.1504965f-5 2.380429f-5 -0.00011271229 9.494021f-6 9.184913f-5; -0.0001083801 4.3394255f-5 0.00010760798 6.9283764f-5 6.0720093f-5 -1.683407f-5 -6.0037954f-5 0.00019400126 9.362527f-5 -5.8040645f-5 -0.00010613625 5.157354f-5 0.00017564878 -1.610638f-5 -5.613919f-5 2.9308803f-5 -5.4047345f-5 -1.2899073f-5 4.6277484f-5 -6.8871054f-7 7.442494f-5 -9.1421556f-5 9.872121f-5 -6.323421f-5 2.8107448f-5 7.302721f-5 -1.0899094f-5 -4.7039637f-5 5.8980448f-5 -0.00013281302 8.782015f-6 -3.9147984f-5; 1.5706723f-6 -2.7996723f-5 -9.498093f-5 0.000109707595 4.7912235f-5 9.717209f-5 -2.2540868f-5 -9.561277f-6 5.7299647f-5 0.00026008242 -4.0616942f-5 -8.717659f-6 -0.00018270877 -0.00013285526 2.3345729f-5 7.565327f-5 1.1031965f-6 5.359386f-5 -5.601795f-5 -4.7394093f-5 -0.00018008861 8.3041865f-5 1.545314f-7 -1.6812432f-5 7.95618f-5 -0.00011437501 -7.233577f-5 -7.2111165f-5 9.066687f-5 0.00016970745 7.184294f-5 -9.6974494f-5; -0.00013360067 7.742072f-5 8.3122846f-5 -8.769131f-5 -6.2281375f-5 5.4089214f-5 1.7891985f-5 -9.436573f-5 -0.00013061501 6.3497486f-5 2.1362282f-5 3.0720163f-5 2.443027f-6 2.695473f-6 -9.6804804f-5 -5.7462734f-5 5.758325f-5 -6.5752385f-5 7.703256f-5 -0.00026648722 0.00013761534 -0.00013671523 0.00010723095 -6.8089044f-5 
8.309701f-5 0.00020986267 -0.00020950257 -0.00020831394 -3.45306f-5 8.711913f-5 3.946406f-6 -7.313652f-5; 0.00010100545 -4.081083f-5 7.491839f-5 -1.5825067f-6 -0.00010688496 -5.6372177f-5 5.207666f-6 0.00023620039 -2.6466238f-5 0.00020514627 0.00015309223 -8.068107f-5 0.00014763558 3.5357425f-5 -3.0119196f-5 -5.662239f-5 3.947805f-5 0.0001234229 -5.5911438f-5 -1.6000926f-5 0.00012827948 -9.674602f-6 -4.171722f-5 2.8406328f-6 6.952296f-5 9.334583f-5 -5.8535952f-5 6.437936f-5 1.4307074f-6 -4.2413456f-5 -2.3954144f-5 2.3093455f-5; -4.6510708f-5 -8.1432496f-5 -4.9553877f-5 -6.8498535f-5 5.3020394f-5 -0.00014837127 -2.4993396f-5 3.4961697f-5 8.736166f-6 9.874889f-5 -4.1380186f-5 6.217019f-6 6.334951f-5 -6.730708f-5 -7.2388444f-5 -5.457352f-5 -0.00018984762 0.0001353767 -4.3204862f-5 5.3426455f-5 -1.34804895f-5 4.6562738f-5 0.000110889385 -0.00014827335 0.00015864453 3.0137458f-5 -8.451129f-5 8.211706f-5 1.8544728f-5 -3.88849f-5 -0.00019242574 3.7238362f-5; -0.00019912525 -1.0916288f-5 -9.6232376f-5 -6.760549f-5 2.4592107f-5 7.932058f-5 1.8410954f-5 0.00010437145 0.00023274387 -5.174045f-5 4.5673853f-5 2.4164596f-5 1.9551562f-5 -0.00026943363 8.186278f-5 7.183995f-5 -4.7618138f-5 -5.239773f-6 0.00015269393 0.0001862971 -0.00020666842 4.550079f-5 6.542024f-5 -3.0407848f-5 -0.00021665079 -0.00014270771 0.00015411052 0.00011561245 -0.00011685554 2.6520755f-5 -0.00026733216 -5.9419897f-5; 0.00012010509 4.8258054f-5 -8.316388f-5 1.2209842f-5 0.00029047197 -5.3456286f-5 -3.0442865f-5 -5.1728428f-5 -6.692857f-5 3.0537012f-5 5.8678976f-5 7.82564f-6 6.166527f-5 -3.6973324f-5 3.3456854f-5 0.000119530414 -3.4881377f-5 2.5762563f-5 1.6980868f-5 2.9895005f-5 7.2188835f-5 -7.841355f-5 -1.3627159f-5 -8.087871f-5 -0.00011663256 0.00017902062 -5.9551603f-5 8.986628f-5 -0.000114199975 3.9008473f-6 9.221502f-5 6.611498f-5; -4.1724506f-6 -5.3825956f-5 8.820635f-5 -2.5777932f-5 -4.2626765f-5 -6.868209f-5 -4.7789155f-5 6.6007466f-5 -0.00011848123 2.0231491f-6 -4.9856517f-5 5.071172f-5 
7.053148f-6 0.0001147518 -3.096124f-5 -0.0001786954 2.8384902f-5 -4.2118256f-5 1.3875376f-5 -2.6917975f-5 -2.4860068f-5 -0.00014146419 7.015344f-6 -3.907246f-5 -1.7840914f-5 -9.151608f-5 0.0001311089 0.00020495891 0.0001339814 -7.018391f-5 -0.00013829491 -8.459504f-5; 0.0001062027 7.520745f-5 -7.5464006f-5 -7.820129f-5 5.4692413f-5 -0.00018291273 -0.00016114942 3.654859f-5 -6.922817f-5 5.0209815f-7 -7.88969f-5 0.00010629573 1.3591315f-5 -8.8694396f-5 -9.700923f-5 -7.792265f-5 -5.3244665f-5 2.5566405f-5 9.5518946f-5 -8.606936f-5 1.5010637f-5 1.6346119f-5 -0.00020224716 9.9279954f-5 -0.00017092249 -4.1199684f-5 6.1716986f-5 -1.828571f-5 -6.4742904f-5 6.141311f-5 3.665967f-5 0.00015882637; -5.3134157f-5 3.571111f-5 -0.0001071512 5.9995502f-5 -2.3541332f-5 -7.981617f-5 -7.268167f-5 2.2832386f-5 -5.2762174f-5 4.7423026f-5 -0.00010546918 0.00014138562 -4.2545147f-5 -1.770352f-5 8.389373f-5 -3.127577f-5 -5.382556f-5 -3.805856f-5 -4.9698865f-5 6.5461325f-5 -0.00017162041 6.930798f-6 -4.481219f-6 -0.00011541513 3.3884542f-5 -0.00026558308 -5.8299946f-5 -0.00010096692 0.00028417318 0.00022399446 -1.8742303f-5 0.00015314913; -2.8444047f-5 1.6567892f-5 3.2681437f-5 -1.8588718f-5 -7.592697f-5 0.00013262371 -8.01927f-5 3.0489913f-5 -0.00020825998 9.213935f-5 -0.00015984585 0.00013850979 0.00012299787 1.7118593f-5 -3.0244519f-6 -9.4906696f-5 -4.7385594f-5 -6.4886815f-5 5.1190942f-5 -7.8480756f-5 -6.476434f-5 -3.078301f-5 0.00014187123 -0.00014201117 4.8327005f-5 -5.749435f-5 -3.0272287f-5 -0.00010431894 -0.00027794487 8.457921f-5 -3.544488f-5 6.957566f-5; 0.00012405527 7.090927f-5 -9.547093f-5 -5.624749f-5 -2.8614442f-5 -0.00025870837 0.00012733704 7.742453f-5 -0.00011473053 -3.7116606f-5 -4.32381f-5 -4.44775f-5 -5.0753108f-5 7.786667f-5 0.00017440702 -0.000102921134 -0.00019033113 -0.00013119761 -9.082414f-5 5.5319182f-5 -5.6503566f-5 2.7747428f-5 -8.560285f-5 -5.90994f-5 -7.0763475f-5 7.0684524f-5 0.00011613276 -1.8678265f-6 4.7004698f-5 -0.00014711992 4.538411f-6 6.597061f-5; 
-2.1668966f-5 -0.00013806454 0.00015222996 -0.00017829033 -0.00013861722 -8.871638f-5 5.350655f-5 -2.0491183f-5 -0.0001419638 7.500167f-5 -0.00014882827 5.4925595f-6 -0.00013733977 0.00014466101 -0.0001348514 0.00011863919 -3.338715f-5 -4.252546f-5 -6.355367f-5 -5.5923785f-5 -6.151603f-6 5.9459806f-5 0.00010971954 3.074991f-5 8.014728f-5 -0.00012124516 -0.00011530138 0.00017042816 -0.00018529152 0.0001515527 -4.7749712f-5 3.6865862f-5; -1.2601707f-5 -3.5469624f-5 8.9262685f-5 3.5632118f-5 -8.335018f-5 0.00011051867 -6.562574f-5 0.00010391501 7.2400006f-5 -7.732753f-5 6.6265064f-5 -3.5518522f-7 -4.5292258f-5 9.174502f-5 2.1575612f-5 -0.00012630774 -1.3433603f-5 7.605811f-5 2.9383744f-5 -0.0002371179 9.140218f-6 -5.2299347f-5 3.4501857f-6 -2.8361754f-5 4.0957453f-5 2.6299056f-5 -0.0001228758 -7.60514f-5 -7.48981f-5 -0.00010468629 -0.00016262839 8.330861f-6; 0.00017458422 -6.526823f-5 3.9015133f-5 -9.070606f-6 3.0546056f-5 1.615465f-5 -5.4537f-5 0.0001961071 0.00011681229 8.063602f-5 -4.0023675f-5 9.792304f-5 0.00010737351 -5.1229086f-5 0.00020383243 8.370374f-5 7.5176445f-6 -0.00013720997 7.720838f-5 -0.00014382745 0.00011211199 -8.811881f-5 2.6984422f-5 -5.660796f-5 -0.00020003763 -2.5203923f-5 0.00023427531 -9.029806f-5 3.3874494f-5 4.465718f-5 -0.00012938977 -8.3209605f-5; 0.00015996037 -8.6771564f-5 -9.360413f-6 7.602068f-5 -3.4988505f-5 -0.00011582998 6.3358035f-5 0.00024041132 0.00018234852 1.975689f-5 -2.228824f-6 9.244541f-6 0.00010146939 2.678402f-5 6.213541f-5 -9.078045f-5 1.1416683f-5 0.00022877238 -9.4782445f-5 -2.8187222f-5 3.0114803f-5 4.885755f-5 -0.00014531161 6.229832f-5 -1.7601025f-5 0.0001012602 -8.630344f-5 6.962805f-5 -0.00011904664 3.84476f-5 0.00015080922 -6.07305f-5; 2.5019708f-5 -4.308669f-5 3.2099854f-6 -1.070475f-5 -7.1183626f-5 1.7475644f-5 -0.00013197928 9.631055f-5 0.000149675 2.0617274f-5 7.940401f-5 -4.4641656f-5 -5.8994356f-5 8.903909f-5 5.7457415f-5 0.00017483684 -2.53013f-5 -1.49022635f-5 -6.557858f-5 0.0001592803 -0.0001140723 
-8.931739f-5 -8.774388f-5 -3.820522f-5 -0.00023867318 5.9817f-5 -0.00013119371 0.00011898252 2.6324933f-5 -0.0001758107 -6.6001645f-5 6.313187f-5; -0.00015309908 0.0001426998 -3.2693846f-5 -1.3040735f-6 0.00014535329 -0.00012739559 -3.2720258f-5 3.3698423f-5 0.00020791453 5.1849816f-5 2.0268108f-6 6.429721f-5 6.0508104f-5 -0.00019773634 0.000108098415 -1.4077876f-5 -0.00017357737 -7.1772592f-6 -7.6222386f-5 0.00014064147 6.149977f-5 -7.8262856f-5 1.8839732f-5 0.000110933724 4.835375f-5 2.5659114f-5 -0.00014640919 -2.8351891f-5 -0.00016531482 -0.00025105465 0.00020287237 0.0002243887], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[1.5997928f-5 0.00014728575 0.00017078283 -5.6750403f-5 -6.646122f-5 0.00024876773 0.00016268472 2.6614085f-5 0.00010973593 -0.00010702466 -8.039563f-5 -0.00013119199 -7.560695f-5 -0.00011244174 -2.0466572f-5 3.2376218f-5 -0.00021143115 7.41661f-5 -4.4710916f-5 -5.727751f-6 -2.3214756f-5 6.2922656f-5 -5.0656705f-5 -4.7631514f-5 -9.438545f-5 9.277793f-5 -9.024292f-5 -5.6760793f-5 -0.00011689727 -2.582012f-5 6.588881f-5 1.4327455f-5; 0.00022048381 -0.00014130473 -9.217048f-5 6.7877205f-5 -7.90706f-5 6.5924555f-6 6.0525726f-6 0.00018886756 9.22664f-6 -0.00010882886 -2.2435765f-5 2.5857107f-6 4.886121f-5 -0.0001666136 -8.012143f-5 5.6968016f-5 0.00023357538 1.651197f-5 -5.8851616f-5 -0.00012827548 3.467483f-5 6.9379974f-5 1.5464773f-5 -0.00014696848 -5.707658f-5 -9.530237f-6 9.141111f-5 7.101698f-5 0.000109753506 0.00014069243 7.841061f-6 1.36926865f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most DL frameworks, Lux defaults to using Float32
, however, in this case we need Float64
# Lux defaults to Float32; promote the parameters to Float64 for the ODE solve.
const params = ComponentArray(ps |> f64)
# The network state `st` is empty here, so wrap the model once and reuse it
# inside the ODE right-hand side.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
Now we define the system of ODEs describing the motion of a point-like particle with Newtonian physics, augmented by a neural network correction:
where,
# Newtonian-order orbital dynamics with a learned correction: the neural
# network rescales both angular rates as a function of χ = u[1].
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple, hence we can
    # safely ignore it; in general, `st` should be used to thread the state of
    # the neural network through.
    correction = 1 .+ nn_model([first(u)], nn_params)
    prefactor = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [prefactor * correction[1], prefactor * correction[2]]
end
ODE_model (generic function with 1 method)
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Simulate the (untrained) neural ODE with the same fixed-step RK4 setup and
# extract its plus-polarization waveform for comparison against the data.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
# Overlay the reference waveform and the untrained neural-ODE waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
    fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()

# Objective: simulate the neural ODE with parameters θ, compute its waveform,
# and compare against the reference `waveform`. The predicted waveform is
# returned as a second value so the training callback can inspect it.
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform), pred_waveform
end
loss (generic function with 1 method)
Warmup the loss function
# Trigger compilation of the loss (and the underlying ODE solve) once before
# timing-sensitive training begins.
loss(params)
(0.0007166245597729142, [-0.02426909472654956, -0.023483915795636957, -0.022698736864724598, -0.021372047154869133, -0.01947610484549094, -0.016970642081981784, -0.013802792177240357, -0.009904128868986565, -0.0051929403610022325, 0.0004261088002610532, 0.007052742079547345, 0.014770056229519664, 0.023596753232773496, 0.033358370984896274, 0.04340810524799511, 0.051966781995033316, 0.05472705829685283, 0.04258030712349483, 0.001969840206789606, -0.0663821577085762, -0.11040039149596034, -0.07624712082066967, -0.006685463780116319, 0.03895416936684998, 0.05435219369819673, 0.05296804126750264, 0.04481683589664446, 0.0347711081503614, 0.02486718491127757, 0.015861972213674684, 0.007968474574375146, 0.0011846578057420985, -0.00457066484396455, -0.009398730207400518, -0.01339777196684956, -0.01665190156779149, -0.019232595054077276, -0.021195674463974377, -0.022583500891852626, -0.023425879970726524, -0.023741045469850237, -0.023536017507062308, -0.02280788655075396, -0.021541478139634408, -0.01971148249001068, -0.01728011032142346, -0.014196344958884125, -0.010395394903554473, -0.005797936691822398, -0.00031169339105488757, 0.006161011424379674, 0.013707210863452841, 0.02235467571069252, 0.03196754087302741, 0.041991654166337586, 0.05087612416358229, 0.05481885150067479, 0.04558706893637835, 0.009889323862385975, -0.056073031646418046, -0.10846917882619708, -0.08536910374982667, -0.016020327498146136, 0.034657758070166765, 0.05365321912965863, 0.05386022687139579, 0.04621110920721561, 0.03620317589105212, 0.026166367043139144, 0.016983038834965564, 0.008908763712356303, 0.0019659612872822995, -0.003932199338246643, -0.008878063611560644, -0.012981537471392576, -0.0163222931547283, -0.018981626271723888, -0.021011638421713177, -0.022462378265797986, -0.02336180683155706, -0.023732664508755028, -0.023582903638458635, -0.02291041574636718, -0.02170390009658089, -0.019939768435586425, -0.017579290349309853, -0.014579806122271647, -0.010872014961370309, 
-0.00638742758276175, -0.0010284929512729838, 0.005292738958876385, 0.012672391829574157, 0.021140786718434793, 0.030600695671734433, 0.04057288873983958, 0.04971274466629215, 0.05466719175201668, 0.04802989899961981, 0.01703819525477592, -0.045595278634280464, -0.1046794646244152, -0.09346402307776028, -0.025924632298089098, 0.02964053208126005, 0.05258304765466304, 0.05462162996904037, 0.04758211153593092, 0.03765153138364236, 0.02749639407818875, 0.0181304977735022, 0.00987882920533422, 0.002765534021186933, -0.0032716733440376375, -0.0083457088127025, -0.012550191690961401, -0.01598528588285112, -0.018719943354969203, -0.020822138554170665, -0.022333604536539046, -0.023292599798174364, -0.023718498955245762, -0.023623511971377673, -0.023007974794918624, -0.02185952588513882, -0.020159639574962144, -0.017870307366408617, -0.01495145947810392, -0.011336419783660266, -0.006959990245967133, -0.0017266548921875192, 0.004448605043307346, 0.011663534675768701, 0.019956769685198456, 0.02925697982255306, 0.039158574965256925, 0.048491394005046616, 0.05430369131317174, 0.04996730680320123, 0.02340427411641193, -0.03518530347487399, -0.09919046037619889, -0.10024592018984436, -0.036259770346072626, 0.023860164840428123, 0.05109456218020804, 0.055225885047218426, 0.048920375570612726, 0.03911556313685846, 0.02885038265125619, 0.0193129850494142, 0.010871980950291875, 0.0035918266127311113, -0.002598370024449912, -0.0077922104261353466, -0.012109365390869814, -0.015635374631353605, -0.018453281036711455, -0.020622015983614902, -0.022199816182698846, -0.023216921182854948, -0.02369886181311036, -0.023658864647614145, -0.02309880832759219, -0.022009557301210632, -0.020369920125118898, -0.01815480936007966, -0.015311679581422816, -0.011787116273986027, -0.0075179001799075, -0.0024073609801513367, 0.003628241800648, 0.010682229517521827, 0.018799366055336678, 0.02793862078076006, 0.037752425608216435, 0.04722402345150401, 0.05376270954417922, 0.051447732437467635, 
0.02901021775897267, -0.02506617176865673, -0.0922016811903469, -0.10547056012477442, -0.046847214118606394, 0.017295338797052143, 0.04912751792902425, 0.05564781268856679, 0.0502142897092844, 0.04058750330669737, 0.030238439316673044, 0.020520662252582266, 0.011894308881488973, 0.004437829556465965, -0.001897345710280913, -0.00722845264665199, -0.01165453586419174, -0.015277056137410474, -0.018172320853735897, -0.02041796777527581, -0.02205880559213478, -0.023135562645870154, -0.023673095476275222, -0.02368862793554995, -0.02318386687935065, -0.022151210921666702, -0.02057756313149454, -0.01842710212867405, -0.015662604422646668, -0.012226452626293717, -0.008059759549263863, -0.0030686503620924966, 0.0028289100429126014, 0.009725016541336586, 0.017671784342568442, 0.026646333536206563, 0.036357129237817594, 0.045922697785313035, 0.05306853541988115, 0.05252468138542487, 0.03388015153770523, -0.015396551524424925, -0.08397683317914853, -0.10894432792469566, -0.05745418334690027, 0.009924885117599576, 0.046634050869480234, 0.05585191696304461, 0.05145061821548393, 0.04206770457981179, 0.0316494264933922, 0.021761231934514628, 0.012944684773957634, 0.0053104485317720775, -0.001180703306406706, -0.006646113103186789, -0.011186077220489428, -0.014906205261506345, -0.01788538116510913, -0.020204571155770883, -0.0219110912046049, -0.023047962226187838, -0.023642004079472926, -0.02371258816383115, -0.023263086034678405, -0.022287825933534015, -0.020775399687206646, -0.018692399488553366, -0.01600351080099226, -0.012653830060522146, -0.00858712275828064, -0.004520415456039062])
Now let us define a callback function to store the loss over time
# Accumulator for the loss history across training iterations.
const losses = Float64[]

# Optimization callback: record the loss, print progress, and return `false`
# so that training always continues up to `maxiters`.
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf "Training %10s Iteration: %5d %10s Loss: %.10f\n" "" iteration "" l
    return false
end
callback (generic function with 1 method)
Training the Neural Network
Training uses the BFGS optimizer. This works well here because the Newtonian model already provides a very good initial guess.
# Differentiate the loss with Zygote reverse-mode AD.
adtype = Optimization.AutoZygote()
# Wrap `loss` for Optimization.jl; the second argument `p` (fixed
# hyper-parameters) is unused here.
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
# `params` (defined earlier in the tutorial) is the initial guess.
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a backtracking line search; `callback` records the loss each
# iteration and never stops the solve early.
res = Optimization.solve(
optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-7.621237455152889e-5; -6.828003824917514e-5; 3.464554538369519e-5; 1.9029168470253102e-5; -6.121949263605459e-5; -6.860006396885529e-5; 7.876623385521802e-6; -1.581312062623181e-5; -0.00010103750537382344; -5.185803820486599e-5; -1.4994273442413124e-5; 9.659223724147794e-5; 0.00016283580043812717; 6.454488175213905e-5; 2.8993046726056413e-5; -3.88987427867797e-5; -3.768680835491772e-5; -6.453134119502595e-5; -3.8890510040795306e-5; 0.00016972773300933213; -4.791819810628294e-5; -0.00017601385479783352; -6.921435851843492e-5; 9.090335515788435e-5; 3.5885025226863496e-6; -1.970678385984691e-5; 0.00010642349661780524; 2.98606664727743e-5; 3.480168015810646e-5; -4.441369674168386e-5; -0.00015934939437979527; 4.4804171921023695e-5;;], bias = [-1.7642498311229118e-16, -4.3367528244030375e-17, -1.4052422225800365e-17, 4.855026388306068e-17, -1.1207343822149232e-16, -6.655275901843062e-17, 6.618567315521975e-18, -3.385037472571622e-17, -1.7154765017903173e-18, -9.811237776717042e-17, -1.4308052273011417e-17, 9.587226528077357e-17, 6.623367842556819e-17, 9.759394868789597e-17, 4.939623549518107e-17, -6.911688628888603e-17, -2.901058698251685e-18, -8.902872085454367e-17, -6.5958990762675894e-18, 1.79466220416824e-16, -6.89266578966841e-17, -1.0658366144097241e-16, 7.807953708448233e-17, 5.623635898134213e-17, 2.9267887855927683e-18, -3.010556611000354e-17, 1.4346655254043104e-16, 4.395296879445218e-17, 2.120413145370029e-17, -6.687402421782222e-19, -3.5551818220859137e-16, -4.0663766628699344e-17]), layer_3 = (weight = [-2.912962807717323e-5 3.566681627158342e-5 2.797726412277498e-5 -4.766627970061541e-5 -5.7649366484515725e-5 -0.00011001672536044112 1.0206690800577987e-5 -0.00017521488624302033 0.0001937140725352619 -0.00023501235844709893 5.2671649624562684e-5 -0.00025270559561580493 -1.984898093372532e-6 -9.48865461247301e-5 0.00017032393517761273 -4.6230650547369436e-6 0.00023471665482825458 
-8.506097654721498e-5 8.272493434228803e-5 -0.00010589575393497428 -0.00010011043630915728 2.6125773016604112e-5 -4.5027959262046e-5 6.530507672901985e-5 -1.0585836010841593e-5 -0.00019777461855815067 -0.00011139304550273169 -7.610028549938671e-5 -6.982052289790642e-5 -6.549337266155592e-5 -0.00013258388642826965 9.567958769944426e-5; 8.200530932146308e-5 6.605354350387477e-5 -0.00014818526038805056 5.4280432903693086e-5 5.763074239730274e-6 0.00010350443984709102 4.9400160731534956e-5 6.935399063725136e-5 -0.00015802629865146743 0.00013335170582497229 7.940672288743896e-5 0.00020657125345524972 -9.582358403505989e-5 0.0002037001460287857 7.6420666467354e-6 8.157423066070221e-5 6.343833878265625e-5 -1.8247633180800167e-5 7.153386026250116e-5 -9.96523748305871e-5 -5.906744705651159e-5 7.817098879811364e-5 -0.00013524735628532838 0.00013553095699381817 9.832513673093317e-5 -4.686422185993661e-5 5.205879928171895e-5 1.2751420525387067e-5 -0.00013009340260548553 0.00013811140446552854 0.00011642532548492716 0.00019347237606627708; 0.00016585175339061873 7.687546471063175e-5 7.584889257488959e-5 3.2035407766692496e-5 -0.00020807680690266048 -7.396014976618686e-5 -0.00011369484551321506 1.758500306097994e-5 5.671103221177153e-5 -0.00011027060523699808 9.128689574938472e-5 3.132845753500283e-5 -6.487093620306138e-5 8.538756386771934e-5 0.00012529964018542515 6.319490361422483e-5 -5.757691602409635e-5 0.00010225951276881756 -0.0001397910424748045 -0.00011696592688165902 -7.450383842294868e-5 1.288098732404074e-5 -4.740218415461357e-5 -0.00021483612787049201 0.0001260735401411003 -4.6008707304259606e-5 2.298585362952677e-5 0.00011769416900281057 7.42945879574325e-5 -8.91155346832996e-6 0.00012118798198573054 1.5295282846658355e-5; -0.00022166504454414286 0.0001975085634078179 -5.9036536138385744e-5 -4.8449402991330976e-5 -2.9119681082894605e-5 -7.405835453734726e-5 0.00011228760017932116 9.43810246574014e-5 -2.582335214153052e-5 4.174997416360187e-5 3.57403169366318e-5 
-2.2180925905347684e-5 -0.00012899756070306113 -0.00020848490926052467 4.41477150636516e-5 -6.08049394466188e-5 0.00011343636474723735 2.2921261963555333e-5 -6.497846650538988e-6 1.3254629513953344e-6 2.387899990228796e-5 0.00013466106797954759 -4.99698234741761e-5 -2.9575940013972647e-5 -4.211142545298023e-5 -0.0001646462605529211 -0.00011347188223932459 6.392769323813503e-5 -6.616752029710655e-5 8.523265755886984e-5 -6.909572943891984e-5 -4.745566361432121e-5; 6.0089400576719764e-5 -2.607386282744036e-5 8.71790786188211e-5 -2.2072473584383628e-5 -4.8805669721157185e-5 -0.00014718388898455224 -0.00010785666869179156 -4.5290247144119746e-5 8.035281804901371e-7 5.6189760143927106e-5 3.2122056509425925e-5 6.165798062301685e-5 1.3341691228742603e-5 3.308600084000568e-5 -4.4417208627972515e-5 -8.210052088525849e-5 -2.3153511719836457e-5 0.00011316382246167701 7.171973506064383e-5 6.278549939470417e-5 4.594846781418909e-5 -0.00014296680215902517 -0.00013794377206050474 -9.612094923778022e-5 -8.539297901716826e-5 0.00018920772949076453 0.00014142858221038746 0.00026129269339452706 -0.00012866109078441527 -0.00014262141245108785 -0.00011064723126746755 -6.206408339476903e-5; 1.0063723915722347e-5 -3.953595890919759e-5 1.839202510751421e-5 -0.00029119154217909524 -0.00011898734968411893 -1.866979304025958e-5 3.6210492420704216e-5 -4.1137695494345135e-5 5.531760094897259e-5 -3.430074944789023e-6 7.921706261555223e-5 -0.00010108492873399219 1.5416060217905985e-5 -8.68309785218456e-5 -1.3024872292535569e-5 5.092796478862536e-5 -0.00022322757880459886 0.0001057753774730942 -0.00013875567541248504 2.657995488706064e-5 3.054701070510509e-5 -9.279201052405675e-5 -0.00014793966559949404 -4.004313317267013e-5 0.00023348014789376879 9.71220482878644e-7 -0.00014717366733378267 -0.00017316705501429193 1.4603698640767131e-5 1.2405282599557452e-5 7.266195595009927e-5 0.00028479628368955733; 2.9945377250602724e-5 0.00013000749503151086 -2.7385714488284776e-5 7.939850955624692e-5 
0.00016632941013556486 6.019745750125071e-5 -0.0003407528270279933 2.2291106220270512e-5 -8.264402232709833e-5 -8.858378670522412e-5 9.315833708730998e-6 8.147294907008592e-6 7.586616434467788e-6 5.543611813225363e-5 0.00015039663456734823 6.169586457433724e-5 -6.276523301023444e-5 3.5856559848231715e-5 -2.042254847294573e-5 1.1455243552676189e-5 -7.076009889244498e-5 0.00014152399715855902 -5.7349272096710834e-5 -7.999828041176781e-5 3.962609524494599e-5 -2.573167104385236e-5 -9.158844433370555e-5 9.78072352929096e-5 4.37960156752317e-5 7.96039462194832e-5 -0.00020796126001626 2.089578491466192e-5; 5.618348905603136e-7 7.67252902233505e-5 1.4976939151042299e-5 0.00014929109306200663 1.844868779249991e-5 -1.4768292435235605e-5 -2.162827290644204e-6 -0.0001336746820515461 1.3183049982270105e-5 1.974500150487398e-5 -0.00011696147271771649 4.69212067860899e-6 -0.00012138432362458251 -6.807402246971607e-5 0.00019422991021662245 -3.264287375253443e-5 -0.00015237475168299673 1.4481014427531133e-5 3.9725357733063356e-5 0.00019861715258688938 -2.3630539134830213e-6 1.8772482458228562e-5 -3.172000220268997e-5 3.9403331125012814e-5 7.533110032915418e-5 -3.807546140646807e-6 0.00013101770013696057 0.00015343101108107008 8.939815301806869e-5 1.3901275219781443e-5 4.74779175878875e-6 -6.820880230855978e-5; -1.5145573428545878e-5 -7.067075975466818e-5 -4.155276300846085e-5 2.3324093404859502e-5 4.077227086416117e-5 8.925832472048403e-5 -0.00013984868027730368 -0.00013317103920964976 0.00010203029333386461 4.550990152527211e-5 1.4855047907964477e-5 1.074228921968817e-5 6.757442125125306e-5 -0.00010482349361520575 -5.080010936602123e-6 -0.0001601015411270264 -0.00014118940643495357 -8.953202822706046e-5 7.639855712657389e-5 0.00013971676335676166 4.3922322274531925e-6 0.00013410804777412347 -9.008017704583999e-5 0.0003307504065614735 -0.00017155271970661814 -0.00011733627079049567 0.00014584570990880885 -3.075147660094741e-5 -5.761037018811685e-7 -0.00010423826651642488 
0.0001293035146827478 -0.0002093959211324753; -3.590221383981251e-5 -0.0001633198640590998 0.0001001135737353429 -6.850899224426181e-5 -4.1947334084195146e-5 8.876162466569822e-5 -0.000107517278210964 -0.00012764483979461314 3.883270839672302e-5 -8.653635593853189e-5 -4.643104473851222e-5 -0.00010850840915715664 -3.068479925763881e-5 2.9752346036023917e-5 -0.0002946587861331691 4.480193676519764e-5 -4.414337816322609e-5 -7.726547626568984e-5 -5.913380009387873e-6 -2.8367881513752897e-5 -3.1246716660271494e-6 9.82471048724601e-6 7.319792932341836e-5 -2.633421861880467e-5 -3.189491833816841e-5 -7.434989091032696e-5 8.243199265073885e-5 -0.0001081667374635626 -0.00011648809011583627 5.755103340581673e-5 -6.77949701410191e-7 0.00010446410162350781; -0.00011711682689620625 3.16833054002051e-6 -8.722028813337178e-5 -5.797533684073344e-5 -0.00016206338566972692 -0.0001134515050404021 -6.914870116872154e-5 2.3115019775453296e-5 -3.3231860452906836e-5 -8.521182551759276e-5 -5.8022623289268215e-5 0.00010394512993970972 4.324643215970938e-5 -0.00010775809731113482 -4.2856282544213115e-5 -0.00017612519239320574 0.00020677770864439027 -8.837600851676185e-5 -0.00010319732335175543 -6.422697422801302e-5 -8.53779210780099e-5 -0.00013041657892296483 0.00011597706818010062 5.05674716072429e-5 -9.510859761632429e-5 3.5475396561544875e-5 -0.00011555170292976858 -0.0001650691710716418 3.617636959405415e-5 0.00010476988156313922 -0.00010511569506031761 0.00012489991697237746; -2.443810539884002e-5 -6.377699814660533e-5 -3.679525529773295e-5 1.3861453183522343e-5 5.622084513848518e-5 -1.4302303540572166e-5 -2.9986265843442908e-6 -0.00014427463524666076 9.961457631394622e-5 -2.7233520966934645e-6 9.952553314465637e-5 2.077047879826101e-5 0.00018157429526870783 -0.0001026680975399735 -0.00014838556585058916 -6.375039724557026e-5 -0.00017823189322993462 8.35856416898956e-5 -5.446816344087309e-5 -0.00021779312287146342 -1.565340339479275e-5 -3.364040007620069e-5 -7.499734664096492e-7 
6.035030760646075e-5 -6.076322374323793e-5 2.4881320271702396e-5 -2.9062469408777813e-5 3.68750379067532e-5 -5.376994799630104e-5 -5.4373157624325534e-5 0.00012312884150849282 -0.00010927014140801764; 6.618465688660414e-5 4.863327178753616e-5 6.933452080119198e-5 3.7331541377341433e-5 -3.8222406303987756e-5 -0.0002249054961546913 3.092844806076636e-6 9.432611098162652e-5 -0.00012147238173639846 0.00010349517090361324 4.292122305486174e-5 1.5971685413374075e-5 0.00014997062665059275 3.6488836328486886e-5 4.560530926681862e-5 0.00016086199305135547 -0.00012663431714173614 7.90014162950562e-5 -0.00014690772529576184 -0.00012629109566916728 0.00011125257135716966 -0.00023642703324441874 -4.210714193760717e-5 8.448282444731869e-5 5.2535829879332494e-5 -0.00011458820533590573 5.415452308369325e-5 -0.0001535472694493477 -1.3775638920208138e-5 0.00013601551456957834 2.254065641607701e-5 0.00014901049128383174; -3.167162852304144e-5 0.0002082128814147963 -5.4595238754414807e-5 -0.00022239780431089996 -4.5097982212324534e-5 7.490133081620473e-5 2.8284725129482507e-5 -5.3857591257552456e-5 -8.718295778788722e-5 6.495124777603813e-5 -0.00010458621166588814 -3.5932092320637496e-6 7.196922903705259e-5 -6.012286378944377e-5 -7.25703683489755e-5 -0.00010184007608739738 0.00012514416100049315 0.00012656512641871263 -5.697784566262816e-5 0.00014042007766721637 -2.3507846723860068e-6 -1.592685278563898e-5 -3.358259330678817e-5 2.6494184720207945e-5 2.3003576796612873e-5 -4.1816459844709214e-5 0.00015351898416003112 -5.1504903087048645e-5 2.3804352320725245e-5 -0.00011271222668409427 9.49408255866136e-6 9.18491897830918e-5; -0.00010837879067778637 4.339556785041464e-5 0.00010760929166779066 6.92850766840595e-5 6.072140553940124e-5 -1.6832757459748577e-5 -6.0036640711154724e-5 0.00019400257173899746 9.362658187953518e-5 -5.8039332138401585e-5 -0.00010613493628127314 5.1574853792290995e-5 0.00017565008915468124 -1.610506711088102e-5 -5.6137877031255774e-5 2.931011608368383e-5 
-5.404603194122067e-5 -1.2897760420425141e-5 4.627879702227321e-5 -6.873977368177257e-7 7.442625378232547e-5 -9.142024332917482e-5 9.872252434931686e-5 -6.323289976947165e-5 2.810876089413379e-5 7.30285204967688e-5 -1.0897781571229505e-5 -4.703832433235674e-5 5.8981760453639136e-5 -0.00013281170372877159 8.783328162702306e-6 -3.914667154688278e-5; 1.5708516893186998e-6 -2.799654321468774e-5 -9.498075342169816e-5 0.00010970777402500774 4.791241480954123e-5 9.717226675813213e-5 -2.254068822556487e-5 -9.561097433181352e-6 5.729982621952987e-5 0.0002600826021061199 -4.061676263707537e-5 -8.717479251647768e-6 -0.00018270858963849952 -0.0001328550813942296 2.3345907900800558e-5 7.565344749120211e-5 1.1033758463943406e-6 5.359403912545311e-5 -5.601777120508109e-5 -4.73939133390413e-5 -0.0001800884299901393 8.304204420520896e-5 1.547107567990206e-7 -1.6812252217353188e-5 7.956198278217731e-5 -0.00011437483299421926 -7.23355940882201e-5 -7.211098527666963e-5 9.066705133401936e-5 0.0001697076269901279 7.18431158001442e-5 -9.697431487628116e-5; -0.00013360289809826787 7.741849188858533e-5 8.312061621507267e-5 -8.769354095749981e-5 -6.228360504212033e-5 5.408698419740318e-5 1.7889755359990727e-5 -9.43679587759307e-5 -0.00013061723799511488 6.349525577787058e-5 2.1360052447906084e-5 3.0717932564935184e-5 2.4407968670849423e-6 2.6932430377029277e-6 -9.68070343911873e-5 -5.7464963868934324e-5 5.758101909703229e-5 -6.575461517388143e-5 7.703032682582394e-5 -0.00026648944546068284 0.0001376131076871203 -0.000136717459066492 0.00010722872324220736 -6.809127440975823e-5 8.309477928958383e-5 0.00020986044186099942 -0.00020950480026259635 -0.00020831617072292058 -3.453282985321608e-5 8.711689955487541e-5 3.944176147704839e-6 -7.313875172587414e-5; 0.00010100879606378022 -4.0807483146834434e-5 7.492173770495945e-5 -1.5791581247072039e-6 -0.0001068816111184938 -5.6368828354805064e-5 5.211014722397379e-6 0.00023620373556676677 -2.6462889054732122e-5 0.000205149613775379 
0.00015309557946219468 -8.067772270826313e-5 0.00014763893139366376 3.536077363585988e-5 -3.011584723955191e-5 -5.661904126119577e-5 3.9481399453312614e-5 0.00012342624702776719 -5.590808925281987e-5 -1.5997577266031193e-5 0.00012828283231991477 -9.671253083296391e-6 -4.171387101470624e-5 2.843981319327308e-6 6.952630954382393e-5 9.334917720460675e-5 -5.853260356182138e-5 6.438270955167067e-5 1.4340559162396662e-6 -4.2410107398804424e-5 -2.3950795566854354e-5 2.3096803288158786e-5; -4.651235175183696e-5 -8.143413978027895e-5 -4.955552102397163e-5 -6.850017924658037e-5 5.301875027148221e-5 -0.0001483729134509784 -2.4995040427125342e-5 3.496005266574193e-5 8.734521919587152e-6 9.874724529585913e-5 -4.138183029349913e-5 6.215374845532922e-6 6.334786340963543e-5 -6.730872663534058e-5 -7.239008808606293e-5 -5.457516416422974e-5 -0.00018984926244807384 0.00013537505620992931 -4.320650585791979e-5 5.342481055200668e-5 -1.3482133532628904e-5 4.656109414432883e-5 0.00011088774082386159 -0.0001482749936134048 0.00015864288472006998 3.0135814462159654e-5 -8.451293214088666e-5 8.211541754794321e-5 1.8543083946174344e-5 -3.88865442676928e-5 -0.0001924273816133252 3.7236717993114894e-5; -0.00019912656874699507 -1.0917606902202106e-5 -9.623369418155595e-5 -6.760681002238992e-5 2.459078883797778e-5 7.931926508130462e-5 1.840963547105195e-5 0.00010437013248841978 0.00023274255205573704 -5.174176998158113e-5 4.5672534055064893e-5 2.4163277215432984e-5 1.955024369049304e-5 -0.00026943494752818795 8.186146284382583e-5 7.183863477943286e-5 -4.761945614196542e-5 -5.2410916727776245e-6 0.0001526926119036396 0.0001862957781336741 -0.00020666974318051855 4.54994717652458e-5 6.541891913324554e-5 -3.0409166116857747e-5 -0.0002166521096446203 -0.00014270902623726277 0.00015410919719537958 0.00011561113048041035 -0.00011685685761809542 2.651943601771565e-5 -0.00026733347634621987 -5.9421215825259404e-5; 0.00012010715500920138 4.826011855740235e-5 -8.316181855505898e-5 1.2211905879408162e-5 
0.0002904740372393579 -5.345422185954506e-5 -3.0440800845357816e-5 -5.172636382511386e-5 -6.69265086692502e-5 3.053907631666667e-5 5.868103972872221e-5 7.827703949905096e-6 6.166733284016676e-5 -3.697126020441594e-5 3.3458918107239706e-5 0.00011953247804896321 -3.4879313250972154e-5 2.5764627491037188e-5 1.698293227647312e-5 2.989706946363591e-5 7.219089958354193e-5 -7.841148407197318e-5 -1.3625095096620392e-5 -8.087664399526188e-5 -0.00011663049598454053 0.0001790226792870411 -5.9549538935509815e-5 8.986834064827688e-5 -0.00011419791052347615 3.902911441140703e-6 9.221708524973742e-5 6.611704658254364e-5; -4.17419959752245e-6 -5.382770449990361e-5 8.820459789227543e-5 -2.5779680976608888e-5 -4.262851416036493e-5 -6.868384206803469e-5 -4.779090431305795e-5 6.600571669501978e-5 -0.00011848297708104936 2.021400183243012e-6 -4.985826580572148e-5 5.070996940582368e-5 7.051398830189133e-6 0.00011475004885675402 -3.096298950744033e-5 -0.00017869714337770027 2.838315322684451e-5 -4.212000475866702e-5 1.3873626818043246e-5 -2.691972349031816e-5 -2.4861817104589148e-5 -0.00014146593729866998 7.013595228150312e-6 -3.907420886583379e-5 -1.7842662595410217e-5 -9.151782705351909e-5 0.00013110715099707668 0.0002049571603546275 0.00013397965540740492 -7.018565792726458e-5 -0.0001382966611291719 -8.459679065175459e-5; 0.00010620047207239205 7.520522005406249e-5 -7.546623426706744e-5 -7.820351500514236e-5 5.4690184623626716e-5 -0.00018291495696858462 -0.00016115165097390597 3.654636195177864e-5 -6.923039662169786e-5 4.998698103854056e-7 -7.889912565688884e-5 0.00010629350246644953 1.3589086774038023e-5 -8.869662458992559e-5 -9.701145798417978e-5 -7.792487493234877e-5 -5.324689307700107e-5 2.5564177048085287e-5 9.55167178461961e-5 -8.607159004994748e-5 1.5008408735313991e-5 1.634389074653464e-5 -0.00020224938824058464 9.927772582059093e-5 -0.0001709247172412772 -4.120191268258067e-5 6.171475790224557e-5 -1.8287937638077427e-5 -6.474513255044099e-5 6.141088480844007e-5 
3.6657440358713725e-5 0.0001588241411071907; -5.313533122529143e-5 3.570993564795792e-5 -0.0001071523752470412 5.9994328214143596e-5 -2.3542505485182086e-5 -7.981734463822936e-5 -7.268284619198267e-5 2.283121236547284e-5 -5.276334789226464e-5 4.742185180663146e-5 -0.0001054703557455827 0.00014138444882707006 -4.2546320881163846e-5 -1.7704693644134554e-5 8.389255296793711e-5 -3.1276943016508885e-5 -5.382673273555968e-5 -3.805973258562315e-5 -4.970003876945964e-5 6.54601512652084e-5 -0.00017162158704007977 6.929624280085933e-6 -4.482392821898906e-6 -0.00011541630604182144 3.388336833455513e-5 -0.0002655842495441351 -5.83011192501861e-5 -0.00010096809503733233 0.0002841720028955578 0.00022399328506440843 -1.8743476655735278e-5 0.00015314795516320495; -2.84464249120447e-5 1.656551379760368e-5 3.267905871375304e-5 -1.859109666789176e-5 -7.592934613008018e-5 0.0001326213323778111 -8.019508000821124e-5 3.048753482339204e-5 -0.00020826236046509288 9.213697288991837e-5 -0.00015984822581181538 0.00013850740746470632 0.00012299548783682216 1.7116214658516206e-5 -3.0268301629387064e-6 -9.490907446839053e-5 -4.738797266277628e-5 -6.488919367207144e-5 5.118856373933007e-5 -7.848313449306308e-5 -6.476671747755334e-5 -3.078538698848998e-5 0.00014186885622670986 -0.0001420135440270751 4.8324626769085695e-5 -5.7496729786593833e-5 -3.027466569569274e-5 -0.0001043213187217172 -0.00027794724788641714 8.457683092323964e-5 -3.544725674094792e-5 6.95732824538452e-5; 0.00012405280700532797 7.090681152531867e-5 -9.547339162908779e-5 -5.624994831977764e-5 -2.8616901763045134e-5 -0.00025871083220353356 0.0001273345840314552 7.742206777093572e-5 -0.00011473298967701887 -3.711906560001132e-5 -4.324055793349063e-5 -4.4479959105398656e-5 -5.0755567135067225e-5 7.786421316323751e-5 0.0001744045582750967 -0.00010292359387621545 -0.00019033359443655507 -0.00013120006791668063 -9.082660129878931e-5 5.531672264681236e-5 -5.650602564800729e-5 2.7744968757026144e-5 -8.560530863123016e-5 
-5.910185813207837e-5 -7.076593486944594e-5 7.068206447057075e-5 0.00011613030052159498 -1.8702860319482447e-6 4.700223849850421e-5 -0.00014712238065785023 4.535951357526947e-6 6.596815343128295e-5; -2.1671352540373213e-5 -0.000138066927418344 0.00015222757055400758 -0.00017829271392454568 -0.00013861960915871813 -8.871876353421089e-5 5.3504161671500495e-5 -2.0493569634507412e-5 -0.00014196618586335673 7.499928314448337e-5 -0.00014883065897063538 5.490172854688081e-6 -0.000137342154728463 0.0001446586246786185 -0.00013485379177634744 0.00011863680255575947 -3.338953693653553e-5 -4.252784502366439e-5 -6.355605738260873e-5 -5.592617172419856e-5 -6.1539895866883904e-6 5.9457419313046724e-5 0.00010971715518483126 3.0747523753442494e-5 8.01448913009351e-5 -0.0001212475473110116 -0.00011530376660427335 0.00017042577373306133 -0.00018529390227538643 0.00015155030986736476 -4.775209900595344e-5 3.68634754451918e-5; -1.2603738545598273e-5 -3.5471655868909435e-5 8.926065303572867e-5 3.5630086054572684e-5 -8.335221544651665e-5 0.00011051663560980643 -6.562777169559745e-5 0.00010391297647917064 7.239797370463686e-5 -7.732956335924988e-5 6.62630317635031e-5 -3.57217109541413e-7 -4.5294289597512804e-5 9.174299149496158e-5 2.1573579866219482e-5 -0.00012630977479087687 -1.343563515941883e-5 7.605607781048541e-5 2.9381711934241146e-5 -0.0002371199317024704 9.138186138137445e-6 -5.230137874997636e-5 3.4481537710829707e-6 -2.836378560853708e-5 4.095542127803304e-5 2.629702422897226e-5 -0.00012287783655149712 -7.605343315329941e-5 -7.490012838783261e-5 -0.00010468832086584069 -0.00016263042387829017 8.328829531513632e-6; 0.0001745861297219001 -6.526631507073125e-5 3.901704592648583e-5 -9.068692927356136e-6 3.0547969506769405e-5 1.6156563722360244e-5 -5.4535085221102994e-5 0.00019610900897318247 0.00011681420408207051 8.06379318215844e-5 -4.002176167192187e-5 9.792495227669479e-5 0.00010737542441131521 -5.122717295522366e-5 0.00020383434428153274 8.370565017470758e-5 
7.519557746164746e-6 -0.0001372080517488581 7.721029367316183e-5 -0.00014382553519897334 0.0001121139063594825 -8.811689651554757e-5 2.6986335503933265e-5 -5.660604831279075e-5 -0.00020003571291668575 -2.5202009348830013e-5 0.00023427722742565895 -9.029614768438334e-5 3.387640724908249e-5 4.465909267095491e-5 -0.000129387852505147 -8.32076914979181e-5; 0.00015996338748963004 -8.676854798678745e-5 -9.35739666302552e-6 7.602369856912629e-5 -3.4985488525045766e-5 -0.00011582696191609348 6.336105121890097e-5 0.00024041433476967736 0.00018235153817229856 1.9759906138724085e-5 -2.2258076903270624e-6 9.247557103010274e-6 0.00010147240806260363 2.6787036916437415e-5 6.21384284052433e-5 -9.07774332851743e-5 1.1419699661283505e-5 0.00022877539282789065 -9.477942825170361e-5 -2.8184205290689223e-5 3.011781933609193e-5 4.886056609689967e-5 -0.00014530859647194372 6.230133709621758e-5 -1.7598008923148536e-5 0.00010126321700526384 -8.630042014985501e-5 6.963106403717706e-5 -0.00011904362639753245 3.845061706969342e-5 0.00015081224097526956 -6.0727484594017166e-5; 2.5018666145358e-5 -4.308773137539314e-5 3.208943406448433e-6 -1.0705791884661387e-5 -7.118466752375196e-5 1.7474602226017368e-5 -0.00013198032107696714 9.63095065692549e-5 0.00014967396267663252 2.061623157557345e-5 7.940296776565884e-5 -4.464269810499155e-5 -5.8995397613876286e-5 8.903804549687837e-5 5.745637312271814e-5 0.00017483579571332057 -2.5302342384506793e-5 -1.4903305474703051e-5 -6.557961902375129e-5 0.00015927925991334211 -0.00011407334537737772 -8.931843482216557e-5 -8.774492167661584e-5 -3.820626230820581e-5 -0.00023867422138367626 5.981595890104006e-5 -0.00013119475048527522 0.00011898147780654206 2.6323891086817807e-5 -0.0001758117374827039 -6.600268685517757e-5 6.313082698240167e-5; -0.00015309844949617496 0.00014270044137854637 -3.269321079234942e-5 -1.303438162285444e-6 0.00014535392491277834 -0.00012739495051893832 -3.271962251848868e-5 3.3699058425991276e-5 0.00020791516961660633 
5.18504517317339e-5 2.027446120773861e-6 6.42978435139213e-5 6.0508739429200673e-5 -0.0001977357073115133 0.00010809905014364637 -1.4077240316983159e-5 -0.00017357673406469991 -7.176623890264351e-6 -7.622175091048131e-5 0.00014064210207331937 6.150040515829434e-5 -7.826222046380893e-5 1.8840366996736677e-5 0.00011093435985865507 4.835438683992857e-5 2.565974917568398e-5 -0.00014640855571589992 -2.8351255801244356e-5 -0.00016531418570186365 -0.0002510540120197403 0.00020287300374955024 0.00022438933463898118], bias = [-3.875160155308591e-9, 4.837728125630538e-9, 5.487345552385669e-10, -1.8898824909877184e-9, -8.561617583015937e-10, -2.5379776534976857e-9, 3.7879106219481923e-10, 2.0052039558145632e-9, -1.0641943150399122e-9, -3.8067626357289625e-9, -4.628400908882598e-9, -2.3056116741189397e-9, 8.293019963162328e-10, 6.199872044933642e-11, 1.3128060121987606e-9, 1.793504614885018e-10, -2.230047923358518e-9, 3.3485592102048677e-9, -1.6439902133647748e-9, -1.3185276301630261e-9, 2.0641088809846144e-9, -1.7489529795444279e-9, -2.228335787729176e-9, -1.1737283234816013e-9, -2.3782917696229954e-9, -2.4594914019936064e-9, -2.3866121212140036e-9, -2.031887868802623e-9, 1.9132836698239018e-9, 3.0163967425278544e-9, -1.041973278247232e-9, 6.353588835327854e-10]), layer_4 = (weight = [-0.0006718122727404353 -0.0005405242948348481 -0.0005170277137384525 -0.0007445608726162324 -0.0007542717606932963 -0.0004390426914089011 -0.0005251258288071163 -0.0006611963753687249 -0.0005780746019681752 -0.0007948348692818622 -0.0007682056750199673 -0.0008190024171062159 -0.0007634174873826743 -0.0008002522937750759 -0.0007082770867892582 -0.0006554343357925255 -0.000899241576989253 -0.0006136442030680258 -0.0007325214074138757 -0.0006935382659347849 -0.0007110252107006617 -0.0006248878287347058 -0.0007384671423282585 -0.0007354420370588787 -0.0007821958668179144 -0.0005950324865631218 -0.0007780533336175416 -0.0007445712494258412 -0.000804707731827364 -0.0007136304588015345 
-0.0006219217205597736 -0.0006734830904427547; 0.0004439705938988827 8.218200724546903e-5 0.00013131641922928838 0.00029136407647359163 0.0001444162899665492 0.00023007931042921703 0.00022953947068506922 0.00041235443017090414 0.0002327135309403386 0.00011465792530309854 0.000201050970287389 0.0002260725681416791 0.0002723481032535314 5.687330497776777e-5 0.00014336545977348768 0.0002804549148118075 0.0004570622393776388 0.00023999878822950218 0.0001646352627818645 9.521140425179368e-5 0.00025816169795630036 0.0002928668504012495 0.0002389516344688959 7.651840938203167e-5 0.00016641027509106982 0.0002139566187439731 0.00031489796250299556 0.0002945038443451502 0.0003332403761374526 0.0003641792568650957 0.0002313279523589222 0.00023717958252961877], bias = [-0.0006878105543907823, 0.00022348689905386573]))
Visualizing the Results
Let us now plot the loss over time
begin
    # Plot the recorded training losses as a line with circular markers.
    loss_fig = Figure()
    loss_ax = CairoMakie.Axis(loss_fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    iterations = 1:length(losses)
    lines!(loss_ax, losses; linewidth=4, alpha=0.75)
    scatter!(loss_ax, iterations, losses; marker=:circle, markersize=12, strokewidth=2)
    loss_fig
end
Finally let us visualize the results
# Re-solve the neural ODE with the trained parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
# Fixed-step RK4 on the same time grid as the data (`adaptive=false` keeps the
# steps aligned with `tsteps`).
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a tuple; the first element is the plus
# polarization used for comparison below.
waveform_nn_trained = first(compute_waveform(
dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    # Overlay the ground-truth waveform with the untrained and trained
    # neural-network predictions on a single axis, in that draw order.
    wf_fig = Figure()
    wf_ax = CairoMakie.Axis(wf_fig[1, 1]; xlabel="Time", ylabel="Waveform")
    legend_entries = []
    for series in (waveform, waveform_nn, waveform_nn_trained)
        ln = lines!(wf_ax, tsteps, series; linewidth=2, alpha=0.75)
        sc = scatter!(wf_ax, tsteps, series; marker=:circle, alpha=0.5,
            strokewidth=2, markersize=12)
        push!(legend_entries, [ln, sc])
    end
    axislegend(wf_ax, legend_entries,
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    wf_fig
end
Appendix
# Report the Julia/system configuration used to generate this page.
using InteractiveUtils
InteractiveUtils.versioninfo()
# Additionally report GPU toolchain info, but only when the corresponding
# packages are loaded and the device backend is actually functional.
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.