Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning.
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
Printf, Random, SciMLSensitivity
using CairoMakieDefine some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, these details aren't particularly relevant.
We need a very crude two-body path. Assume the one-body motion is the Newtonian two-body relative position vector.
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of two bodies with
masses `m₁` and `m₂`, placing their centre of mass at the origin.

Returns `(r₁, r₂)` with `r₁ = (m₂ / M) * path` and `r₂ = -(m₁ / M) * path`, where
`M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    # Each body's share of the separation is set by the *other* body's mass.
    return (m₂ / total_mass .* path, -m₁ / total_mass .* path)
end
# one2two (generic function with 1 method)
# Next we define a function to perform the change of variables:
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution expressed in orbital variables into Cartesian `(x, y)` positions.

Each column of `soln` is one time sample. `soln` must have either
- 2 rows `(χ, ϕ)`: the constants `(p, M, e)` are then read from `model_params`, which
  must have exactly 3 entries (`M` is not used here); or
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary per sample and `model_params` is ignored.

The radius is `r = p / (1 + e cos χ)`. Returns a `2 × size(soln, 2)` matrix whose rows
are `x = r cos ϕ` and `y = r sin ϕ`.

Throws `ArgumentError` if `soln` has neither 2 nor 4 rows, or if it has 2 rows and
`model_params` is `nothing` or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    nrows in (2, 4) || throw(ArgumentError("size(soln,1) must be either 2 or 4"))
    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Guard against the default `nothing`, which would otherwise fail inside `length`.
        (model_params !== nothing && length(model_params) == 3) ||
            throw(ArgumentError("model_params must have length 3 when size(soln,1) = 2"))
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end
# soln2orbit (generic function with 2 methods)
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d_dt(v::AbstractVector, dt)

First time derivative of the uniformly sampled series `v` with sample spacing `dt`.

Interior samples use the central difference `(v[i+1] - v[i-1]) / 2dt`. The first and last
samples use second-order one-sided stencils. `v` must have at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    tail = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(head, interior, tail) / dt
end
# d_dt (generic function with 1 method)
# This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
"""
    d2_dt2(v::AbstractVector, dt)

Second time derivative of the uniformly sampled series `v` with sample spacing `dt`.

Interior samples use the central stencil `v[i-1] - 2v[i] + v[i+1]`. The first and last
samples use second-order one-sided four-point stencils, so `v` must have at least 4
samples. Everything is divided by `dt^2`.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, interior, tail) / (dt^2)
end
# d2_dt2 (generic function with 1 method)
# Now we define a function to compute the trace-free moment tensor from the orbit.
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free mass quadrupole moment along `orbit`, scaled by
`mass`.

`orbit` is a matrix whose first two rows are the `x` and `y` coordinates. `component` is
an index pair:
- `(1, 1)` gives `x² - tr/3`;
- `(2, 2)` gives `y² - tr/3`;
- any other pair gives `x * y`.

Here `tr = x² + y²`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    trace = x .^ 2 .+ y .^ 2
    i, j = component[1], component[2]
    tmp = if i == 1 && j == 1
        x .^ 2 .- trace ./ 3
    elseif i == 2 && j == 2
        y .^ 2 .- trace ./ 3
    else
        # Off-diagonal entry: it carries no trace correction.
        x .* y
    end
    return mass .* tmp
end
"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the selected
trace-free quadrupole component of `orbit`, as produced by `orbit2tensor`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end
"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(11, 12, 22)` components of the quadrupole strain contribution of `orbit`,
in that order, each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end
"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the `(h₊, hₓ)` strain pair for a single body moving along `orbit`.

The plus polarization is `√(π/5) (h11 - h22)` and the cross polarization is
`-√(π/5) · 2 h12`. Constants are built in the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    scale = √(T(π) / 5)
    h_plus = scale * (h11 - h22)
    h_cross = -scale * (T(2) * h12)
    return h_plus, h_cross
end
"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the `(11, 12, 22)` quadrupole strain components of two bodies, each weighted by its
own mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return (body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3])
end
"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the `(h₊, hₓ)` (2,2)-mode strain pair for two bodies: body 1 with mass `mass1`
on `orbit1`, and body 2 with mass `mass2` on `orbit2`.

The masses must sum to 1 (within `1e-12`); otherwise an `AssertionError` is raised.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    scale = √(T(π) / 5)
    h_plus = scale * (h11 - h22)
    h_cross = -scale * (T(2) * h12)
    return h_plus, h_cross
end
"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an orbital ODE solution into the `(h₊, hₓ)` strain pair.

`soln` and `model_params` are converted to Cartesian positions by `soln2orbit`.
`mass_ratio` must lie in `[0, 1]`:
- `mass_ratio == 0` uses the single-body (test-particle) strain;
- otherwise the relative orbit is split between masses `m₂ = 1 / (1 + mass_ratio)` and
  `m₁ = mass_ratio * m₂`, which sum to 1.

Raises `AssertionError` when `mass_ratio` is outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    # Test-particle limit: no need to split the orbit between two bodies.
    mass_ratio > 0 || return h_22_strain_one_body(dt, orbit)
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
    return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
end
# compute_waveform (generic function with 2 methods)
# Simulating the True Model
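RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using
$$u[1] = \chi, \qquad u[2] = \phi,$$
$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p-2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p-2)^2 - 4e^2}},$$
where $p$, $M$, and $e$ are constants.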
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side `[χ̇, ϕ̇]` of the orbit ODE for the state `u = [χ, ϕ]`, with the constant
parameters `(p, M, e)`. Neither `ϕ` nor `t` enters the rates.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    cχ = cos(χ)
    num = (p - 2 - 2 * e * cχ) * (1 + e * cχ)^2
    den = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = num * sqrt(p - 6 - 2 * e * cχ) / (M * (p^2) * den)
    ϕ̇ = num / (M * (p^(3 / 2)) * den)
    return [χ̇, ϕ̇]
end
# Test particle: with mass_ratio = 0, `compute_waveform` uses the one-body strain.
mass_ratio = 0.0
# Initial state u = [χ, ϕ].
u0 = [Float64(π), 0.0]
datasize = 250
# Time span for the GW waveform (Float32 endpoints) and the equally spaced save times.
tspan = (0.0f0, 6.0f4)
tsteps = range(tspan...; length=datasize)
# Sample spacing passed to the finite-difference strain computation.
dt_data = tsteps[2] - tsteps[1]
# Fixed step size for the RK4 solves below (they use `adaptive=false`).
dt = 100.0
# Constant orbital parameters, ordered as (p, M, e).
const ode_model_params = [100.0, 1.0, 0.5];
# Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
# Integrate the relativistic model with fixed-step RK4, keeping the samples at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns (h₊, hₓ); the plus polarization serves as the training data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    l = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s = scatter!(ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
# Defining a Neural Network Model
Next, we define the neural network model, which takes 1 input (the state variable χ = u[1]) and has two outputs. We'll write a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.
Using globals is generally not recommended, but in case you do use them, make sure to mark them as `const`.
We will deviate from the standard neural network initialization and use WeightInitializers.jl:
# All Dense layers share one weight initializer: a truncated normal with std 1e-4.
const nn = let init = truncated_normal(; std=1e-4)
    Chain(
        Base.Fix1(broadcast, cos),  # elementwise cos of the input
        Dense(1 => 32, cos; init_weight=init),
        Dense(32 => 32, cos; init_weight=init),
        Dense(32 => 2; init_weight=init),
    )
end
ps, st = Lux.setup(Xoshiro(), nn)((layer_1 = NamedTuple(), layer_2 = (weight = Float32[7.1311515f-5; -0.00013078368; -0.0002239552; 8.934998f-5; -0.00020378639; 3.0598425f-5; -4.6988036f-5; -3.420913f-5; -1.337678f-6; -1.1305797f-5; 3.82717f-5; 0.000107510445; -4.5251767f-5; -2.0133122f-5; -0.00021623693; 2.0941157f-5; 8.669663f-5; 8.986749f-5; -7.1266404f-5; 0.000100689416; 1.9042327f-5; 0.00017213565; -2.9612556f-5; -7.3274095f-5; -0.00016167386; -0.00015032192; -1.0876728f-5; 2.9852f-6; -0.000137655; 8.667418f-5; 0.00011985661; 3.706418f-5;;], bias = Float32[-0.3935684, -0.71884716, 0.114752054, 0.89717245, 0.8154042, 0.69903016, -0.23833692, 0.33024538, 0.67412925, 0.76414967, 0.08590925, -0.2749989, -0.31898785, 0.88698494, -0.8884337, -0.067107916, 0.44191396, -0.9261888, -0.09410691, -0.6152959, 0.35559273, 0.13374674, -0.95089066, 0.8691733, 0.29971325, -0.25134587, -0.5746331, 0.9365338, 0.28028142, 0.90126634, -0.051383853, 0.010474324]), layer_3 = (weight = Float32[5.15636f-5 0.00019067976 3.888606f-5 -1.7444561f-5 0.00012710984 1.3399639f-6 -8.401184f-5 -5.232881f-5 0.00012514342 -0.00010558184 0.00017082332 5.271163f-5 8.062669f-5 4.7760866f-5 -0.000119667144 -7.860972f-5 6.461603f-5 -7.092315f-5 -2.354087f-5 4.977906f-5 9.8177916f-5 0.00014749126 7.277514f-7 -0.00017526577 8.114703f-6 -4.2027597f-5 3.1595555f-5 7.090783f-5 -0.00018504492 -0.00016893695 -1.5952046f-5 0.00014143858; -2.2755407f-6 6.096118f-5 -1.4450851f-5 9.357666f-6 2.1938245f-6 2.8449256f-5 -0.00013796848 6.290334f-5 5.128674f-5 4.4559725f-5 0.00010166927 8.7261615f-5 -0.00013558505 0.00016659846 8.248985f-7 -1.0119376f-5 -0.00022594989 2.1480777f-5 -0.00012437269 -0.00019795273 5.4754353f-5 8.063706f-5 -9.091036f-5 0.00017693883 -6.437948f-5 0.00010441896 1.4789692f-5 -4.4743876f-5 -9.914928f-5 3.801028f-5 3.8325714f-5 -3.0435906f-5; -8.5435706f-5 2.6426153f-5 -0.00010575642 -2.3693243f-5 -0.00018920552 -6.982803f-5 3.4295397f-5 7.599107f-5 -2.4642679f-5 -4.504576f-5 -5.9182858f-5 
5.2414143f-5 0.00013944115 -2.8318242f-5 -4.2442312f-5 0.00016759275 0.000177478 -0.00018287104 -0.00011570445 -6.389226f-5 0.00013647739 0.00015706495 5.4785905f-5 -8.7962675f-5 5.310344f-5 3.2038493f-5 -3.8827646f-5 0.00012415677 3.6626578f-5 1.2301277f-5 -0.00011112622 -2.9500992f-5; -7.2534116f-5 8.789086f-5 0.00011558876 -5.6391582f-5 -0.0001546499 2.5663561f-5 -0.00011515457 -0.00013969629 4.5208122f-5 4.047303f-5 -0.00012583767 0.00012813754 2.0328704f-5 -3.1560503f-5 -1.6406897f-5 -7.862273f-5 0.00013978305 6.497126f-5 2.9592855f-5 3.997813f-5 -5.7142952f-5 -3.5775352f-5 -7.0095746f-5 -7.170436f-5 -2.352161f-5 -2.9568153f-6 -0.00017784553 -0.00021148323 9.79221f-5 -4.573704f-5 9.45032f-6 1.0432254f-5; -8.42513f-5 -9.0182606f-5 0.00016580193 -0.00022687232 0.00013576588 -2.3230286f-5 4.583429f-5 0.00010316733 4.27107f-5 5.614725f-5 0.00011300599 -3.7165864f-5 -0.00014973876 -1.8655988f-5 -6.741919f-5 -6.39985f-5 8.252401f-5 3.6439833f-5 -2.6338561f-5 -1.2701409f-5 -0.00010382018 0.00023030484 4.02651f-8 3.5787848f-6 -4.7297f-5 -9.0301175f-5 -0.00010830732 4.8572524f-6 -2.2438813f-5 0.00019308683 3.1326108f-5 -6.1383557f-6; 0.00017715721 6.444965f-5 -2.7973558f-5 0.00013668292 -0.00019245083 7.854711f-5 -0.000115219016 -7.829027f-5 -6.0736267f-5 8.497275f-5 -0.00016530375 -5.708709f-5 8.0759695f-5 -0.00010247963 4.1885472f-5 -0.00020741919 -0.00021628915 -7.6118445f-6 0.0001542263 0.0001289092 -6.859296f-5 -8.116227f-5 -6.7609406f-5 -3.065989f-5 -0.00014312427 -0.00012147762 -8.487813f-5 -0.00020909196 -1.542614f-5 -4.3862285f-5 0.00016672314 -0.00013333757; -1.7027012f-5 -7.8386605f-5 -2.7537635f-6 -3.426607f-5 -0.0001093072 -1.0140402f-5 0.00021774105 -3.442178f-5 6.533635f-5 -5.255871f-5 4.4149296f-5 -9.469507f-5 -9.378705f-5 -0.00026852934 -4.6201007f-5 -1.8231338f-5 -7.6434684f-5 5.9508955f-5 -5.1013903f-5 6.0190618f-5 -0.00033646112 0.00013251539 -0.000118835465 -7.5330645f-5 6.577985f-5 0.00016976857 0.0001479773 5.4768443f-5 0.00022539716 1.7081447f-5 
4.7059195f-5 -4.2382737f-5; 0.0002534501 3.16694f-5 0.00014633496 3.902444f-5 2.5036346f-5 2.560169f-6 -6.528801f-5 5.2367202f-5 0.00020261326 -4.807289f-5 6.290342f-6 -0.00017363098 -0.00010943004 -5.3951517f-6 6.216693f-5 2.3805858f-5 0.00028443852 -6.578317f-5 -5.2412805f-5 -9.7488315f-5 -1.3441609f-5 -0.00010455515 0.00022673697 8.535222f-5 1.9609544f-5 2.0448776f-6 -0.000101947124 -0.00013655076 -0.00013482767 -4.6284906f-5 -3.927193f-5 -4.0323357f-5; -0.0001029424 8.638997f-5 0.00016807053 -0.00020801311 0.00015467464 -0.00014960136 -3.733771f-5 -0.00018042866 -8.0336715f-5 2.0403142f-5 3.1480408f-6 -7.403209f-6 -0.000102960264 0.00015524948 0.000161676 -0.000100118064 9.5990064f-5 0.00018563909 7.097602f-5 -2.5771214f-5 7.736138f-5 9.409067f-5 -7.980458f-5 -1.3846307f-5 -4.5093922f-5 -5.892269f-5 -7.933536f-5 6.475036f-5 -8.791489f-5 -3.9298928f-5 0.00014953945 7.421399f-5; 0.00026504372 -7.692821f-6 -0.00018104819 1.7321714f-5 3.8842204f-6 0.00015996047 -3.724328f-5 -4.1209503f-5 0.00016032571 -6.3076477f-6 -5.2268657f-5 8.155229f-5 -4.19463f-5 -5.2257987f-5 4.1209005f-5 -7.8016936f-5 -3.5511646f-5 0.00014709977 -7.315321f-5 -0.00012653519 2.2044016f-5 -0.00026529454 -6.875417f-5 0.0002733139 -0.0001666699 1.8141718f-5 -0.00012184366 5.904941f-5 0.00019336976 -7.755622f-5 6.685211f-5 5.414752f-6; -0.00011702737 -9.039789f-5 7.9599835f-5 0.0001097231 6.0403236f-5 7.368714f-5 -4.9795082f-5 -7.8271034f-5 7.88244f-6 -5.489562f-5 -6.127763f-5 5.420422f-6 0.00012381485 0.00020219805 0.00016511863 3.38734f-5 7.349872f-5 -8.0285405f-5 0.00019781553 6.586289f-5 5.5728407f-5 -0.00010337465 7.301946f-5 4.8570404f-5 -0.00019983022 -7.2081726f-5 -7.724407f-5 7.878362f-6 7.4036085f-5 -1.7159666f-5 -0.00015000679 4.6705587f-5; 1.2934281f-5 -6.45691f-6 -2.4318186f-5 -0.00010313168 2.71905f-5 0.00012322997 3.379096f-5 6.5755055f-5 -8.253915f-5 -0.0002074419 -2.5740866f-5 -9.228922f-5 0.00016494526 8.333627f-5 6.481775f-5 -9.0779475f-5 0.00016111431 -0.00015401571 
-5.092591f-6 5.7129706f-5 0.0001400248 -1.1343751f-5 -4.3086107f-5 3.688193f-5 0.00012544241 5.3423177f-5 3.193531f-5 -6.195374f-5 -0.00011256676 -1.0741084f-5 -0.00015544378 -0.00016407494; 1.7903052f-5 7.5993444f-7 -8.874757f-9 -9.8887605f-5 1.1979431f-7 3.513754f-5 -0.00012919633 2.5748139f-5 -2.9587332f-5 1.8644905f-5 -5.426895f-5 0.00014281861 -6.1257124f-5 -3.950947f-5 9.8790064f-5 -1.3973963f-5 -0.00016492867 -1.6984584f-5 -7.545234f-5 0.00015263255 -0.0001380934 -5.8184814f-5 4.0253726f-5 -5.2732077f-5 -4.274571f-6 -0.00012813113 6.091037f-5 0.0002514176 -6.452832f-5 -3.223115f-5 -0.00013353708 -0.00024113775; 0.00019302857 0.00011559154 9.2539194f-5 3.297726f-5 2.8544388f-5 3.8123217f-5 8.0666374f-5 7.3173064f-6 -4.4656776f-5 -0.0002006777 1.5995916f-5 -7.4128555f-5 5.1590952f-5 -6.0075936f-6 -0.00013019147 -0.0001638487 -1.4430497f-6 -0.00021935337 -0.00010655699 0.00021270127 0.00010892279 0.0001024269 -4.5405024f-5 -3.9392708f-5 3.112285f-5 -3.246212f-5 -0.0001262576 3.8853617f-5 0.00015670087 -2.587831f-5 -8.9502864f-5 9.295036f-5; 4.8602335f-5 1.2288691f-5 0.00012927591 -2.3040704f-5 4.5472745f-5 2.0581916f-5 -7.785265f-5 -3.4563283f-5 -8.64245f-5 -3.453609f-6 -1.564373f-5 -1.710772f-5 -8.664542f-5 -2.8719409f-5 0.0001002334 2.1782078f-6 -3.7077618f-5 8.696902f-6 0.00013740038 -0.00016546922 8.786545f-6 -1.702692f-5 3.256142f-5 -6.720823f-5 0.0001826004 1.977019f-5 9.498217f-5 0.00020699791 2.2627935f-5 0.0002162803 -0.00019490831 -6.4379055f-5; -6.341886f-5 7.524494f-5 -3.1900843f-5 5.2269887f-5 5.9549715f-5 4.3701733f-5 -9.535151f-5 8.666117f-5 -5.9689323f-6 0.0001635086 0.0002326957 9.531038f-5 -4.9654017f-7 0.000113753085 -8.599619f-6 0.0002140867 -0.00019001636 -1.3367693f-5 -0.0001272221 6.84729f-5 -0.00016786226 5.0299474f-5 0.00013204518 0.0001456368 -4.837989f-5 -7.4342606f-5 6.757259f-5 2.921058f-5 0.00019708004 -7.945372f-6 -0.00015771715 -9.159193f-5; -1.5075297f-5 0.00015351584 -0.00017725473 -6.7183566f-5 -0.00017399278 2.5837082f-6 
0.00015861783 0.00013679489 9.605943f-5 4.3188215f-6 0.00012708888 0.0001592576 -2.4016283f-5 -2.5003661f-5 -0.00022581941 1.8239924f-5 6.756676f-5 0.0001692219 2.3851082f-5 -5.103281f-5 0.00010266097 -0.0002163119 -0.0001281896 -1.1127934f-5 6.346975f-6 0.0002293305 -8.441722f-6 -5.7808745f-7 -0.00011231701 -8.0321646f-5 -9.255215f-5 9.290669f-5; -0.00013687152 -5.9657104f-5 7.443503f-5 -5.1255774f-6 7.8980585f-5 0.00019504306 -2.745266f-5 0.00014065891 -0.00013338319 5.6108758f-5 -5.9499253f-5 -3.7165144f-5 -3.0664913f-5 0.00020231938 -5.2090298f-5 9.541358f-5 4.455624f-5 -4.7847676f-5 -0.00010344691 -0.00012423313 0.00019607277 -8.475692f-5 -3.5450816f-5 2.385414f-6 -0.00019006897 2.9169092f-5 0.00015010752 0.00020149961 6.121339f-5 1.1145612f-5 9.872864f-5 5.0263392f-5; 1.2660383f-5 4.3362263f-5 -2.608302f-5 7.2665025f-5 8.630564f-5 8.0616795f-5 6.8477886f-5 -8.4999076f-5 -0.00010012563 -0.00015041477 7.534838f-5 -3.0457142f-5 -7.383426f-5 -7.8291574f-5 0.00020630266 0.00010878698 -4.6203848f-5 2.9062503f-5 -1.3840483f-5 -0.0002347614 -0.00010321206 2.9436222f-5 9.178899f-5 -0.00018288233 -0.00012565221 2.6628792f-5 -7.8703226f-5 -3.32963f-5 -2.1203356f-5 -9.568237f-5 0.00013806378 6.0717916f-6; -9.467781f-5 5.493053f-6 3.1751263f-6 -3.0703188f-5 1.0468203f-5 6.947027f-5 -4.2314237f-5 0.00018103779 8.5488085f-5 0.0001635316 0.00013080153 1.7636488f-5 -0.0001112317 -8.080309f-5 -6.342014f-5 5.9289538f-5 -3.0686166f-5 -0.000144462 -8.117519f-6 6.1079205f-5 -0.00015054239 2.8523322f-5 0.0001377611 -0.00023868632 -0.00012314986 -3.7609072f-5 0.00011250597 0.00014531032 -9.960775f-6 0.00018910352 7.195069f-5 -0.00010065512; -0.000109625544 8.291354f-5 -0.00021776327 4.668174f-6 0.00018954645 -9.7917655f-5 7.895516f-5 8.536226f-5 0.0001111033 0.00012978638 -7.130381f-5 -7.316371f-5 9.594659f-5 8.964981f-6 6.6840395f-5 -0.000104425984 -5.329817f-5 9.9353165f-5 -3.048356f-5 5.1004754f-5 -4.1833344f-5 9.982553f-5 -6.407971f-5 -0.00010264214 1.8908524f-5 0.00012266771 
-4.7419766f-5 0.00020690035 -0.000106215986 2.4675746f-5 0.00013285299 -0.00019623646; -8.143129f-5 0.000119260636 -4.7851532f-5 0.00019884006 -0.00012450723 5.4643457f-5 4.3390093f-5 7.090103f-5 4.252552f-5 -2.734812f-6 -3.206872f-5 -0.00017025585 5.9238013f-5 -4.0854866f-5 7.689075f-6 -0.00014679901 6.8385307f-6 9.353525f-6 0.00010305149 0.00011149378 -8.458012f-7 -9.85964f-5 7.6371114f-5 -6.7937845f-6 -5.4873926f-5 8.346189f-5 0.00018379785 0.000114317925 7.142563f-5 -2.81204f-6 -8.54917f-5 0.00014911333; 6.800089f-5 -6.1387116f-5 7.4158095f-5 -7.510134f-5 0.000109102395 0.00015358083 -3.2839904f-5 8.24039f-5 -0.00017121823 -0.00012608636 9.689881f-6 1.5132873f-5 -7.2982366f-6 4.6796587f-5 -2.7416607f-5 0.0001662258 -1.5386422f-5 1.1667393f-5 -6.8538124f-5 5.3856165f-5 -2.52642f-5 2.1459326f-5 -0.00012681745 -2.0759593f-5 -0.00016555797 0.000112471294 -0.00020873192 4.731086f-5 -7.806009f-5 1.6923896f-5 -8.666857f-5 -6.354583f-5; 1.6055905f-5 -9.214781f-5 3.9308037f-5 3.3392822f-5 -2.206425f-5 -0.00011164271 -0.00011054727 -2.9351091f-5 0.0002630095 -2.4370649f-5 -6.7568464f-5 -1.3051281f-5 1.2546349f-5 -0.00013098583 0.00022125691 -0.000245611 -0.00013754221 -6.024432f-6 0.00010165594 -0.000103474915 8.120506f-5 -0.00011150022 -5.101117f-6 5.1772363f-6 -4.568732f-5 5.2508403f-5 7.4241684f-6 5.5018543f-5 -6.6729575f-5 1.6823935f-5 -5.916389f-7 0.00011691222; 0.0001646864 -4.142229f-6 1.6305994f-5 4.338502f-6 5.6288725f-5 -5.3282856f-5 9.9316116f-5 7.054002f-5 0.00014977036 9.62237f-5 -0.00013130985 -8.7790235f-5 4.9692022f-5 -0.00022217307 1.8167892f-5 0.00010693777 -0.00012857521 -2.3393142f-5 9.306996f-5 -6.0381168f-5 -8.3237974f-5 4.5449928f-5 -1.695104f-5 -2.9586972f-5 0.00020564764 -4.573562f-5 3.0837346f-6 -4.6060144f-5 -4.042463f-5 -0.000104624516 2.6280903f-5 -2.8668268f-5; 1.9805097f-5 1.4936951f-5 -1.6918091f-5 -5.8857604f-5 -9.0125744f-5 -0.00017295212 5.3509066f-5 7.9406505f-5 -7.841361f-5 -3.4586632f-5 -0.00023340428 -0.00010877198 -0.00011123692 
0.00016576775 0.00015244642 0.00023112883 -0.00013181641 -3.6525205f-6 0.00019814311 -0.00020394394 0.00018931508 7.898528f-5 -6.521945f-5 7.485369f-5 -0.00014172569 -6.618673f-5 2.276073f-5 0.00022425594 -0.00022733075 5.9990063f-5 -4.4765293f-5 0.00015490137; -0.00018362187 2.8344835f-5 5.8756792f-5 -9.042312f-6 0.00015836422 -0.00012203657 -0.00013775907 -8.766842f-5 7.999336f-5 -0.0001546485 -6.2003282f-6 5.836959f-5 7.971577f-5 -4.6827554f-5 7.3622585f-5 -0.00012464872 0.00010568612 -2.3056837f-6 2.071972f-5 0.00012883238 1.4306578f-5 -0.000103827224 -0.00011536503 -7.4774754f-5 -2.8425611f-6 -4.9637896f-5 -2.0438847f-7 -5.816786f-5 1.1013882f-5 0.00013542427 -8.206035f-5 -0.00010210005; -0.00011804372 4.8793205f-5 0.00021875354 0.0002522998 -4.4419536f-5 3.124824f-5 -9.218025f-5 -1.9714538f-5 -5.2572792f-5 4.449607f-5 3.879932f-5 4.9517024f-5 -1.3214128f-5 -0.00012538853 9.526499f-6 1.1003464f-5 0.00013922596 0.0001321466 2.3089328f-7 0.00012948242 -8.0395395f-5 0.00018970041 1.456222f-5 0.00011223969 5.3732947f-5 0.00027434397 -7.338939f-5 0.00011717035 1.5658159f-6 0.00018022941 4.733517f-5 -5.2879903f-5; -6.222677f-5 0.00027757348 3.4437726f-6 -1.4459124f-5 -4.2133192f-6 9.586807f-6 5.8626945f-5 6.916484f-5 1.9328612f-5 4.818579f-5 0.00013721797 -0.00018900583 0.00010223389 -4.0830466f-5 2.2576074f-5 -0.00013741024 -6.526348f-5 -5.948754f-5 9.239704f-5 0.000116690055 -4.361864f-5 5.606485f-5 3.9153245f-5 0.00013993436 -6.717593f-5 -5.53134f-5 6.6579545f-5 -5.6652727f-5 7.904966f-6 -0.00018478406 0.00017564367 -1.0333412f-5; -0.00010444986 -0.000102363476 6.608463f-5 0.00011314458 0.00013857914 -0.000118285105 4.603849f-5 -4.089863f-5 -8.1358856f-5 -1.8741935f-5 -5.3909112f-5 -8.040258f-7 -4.9120696f-5 7.6762815f-5 -0.00012347965 -7.100365f-5 8.7470486f-5 -2.9024992f-5 -0.00010388234 0.00022261453 8.514171f-5 0.00019381267 8.092069f-5 4.3974283f-5 2.363447f-5 -3.0861098f-5 2.1804104f-5 -1.1330622f-5 -2.9188048f-5 8.9880734f-5 5.233489f-6 4.098313f-5; 
-4.9898626f-5 -6.455293f-5 2.4219906f-5 -0.00010985319 -6.856127f-5 -0.00010198282 -7.982736f-5 -5.4867396f-5 6.77942f-5 -0.00017294238 -8.947466f-5 4.3232783f-5 -8.979546f-6 0.0002048518 -0.00026993835 3.590198f-5 -7.731442f-5 -5.6201698f-6 -3.9574155f-5 -9.648026f-5 0.0001457238 -4.0256207f-5 -9.298346f-5 2.9251334f-5 -8.965071f-5 -0.00015286378 -0.0001119524 0.000100935154 8.729911f-6 9.716435f-5 7.367787f-5 -7.197131f-5; 0.00013874641 7.027043f-5 -4.9268456f-5 5.4073564f-5 5.181305f-5 -8.9737376f-5 -9.7730626f-5 -1.0230601f-5 8.339737f-6 -2.2131924f-5 7.285369f-5 1.0869996f-5 -1.2246421f-5 -5.9067464f-5 0.00014589881 0.00017806483 -9.75946f-6 -2.518488f-5 5.834677f-5 0.0001633923 -4.9666975f-5 -5.3861273f-5 0.00011604426 -0.00016577377 -0.0001484102 -3.7265425f-5 -3.026625f-5 9.5837f-5 3.1177653f-5 2.5054213f-5 7.0854716f-5 -0.00010311305], bias = Float32[0.14560096, 0.1495481, -0.0068453224, 0.1369361, -0.12243223, 0.0642342, -0.09587835, 0.13855602, -0.1441696, -0.12956598, 0.017134612, 0.0324556, 0.1677765, -0.006135001, 0.012148766, 0.14453286, -0.16991735, 0.05646685, 0.11063891, -0.0802336, 0.07706256, 0.005536094, 0.012314255, 0.021647317, 0.01715478, 0.0150322635, 0.14215364, -0.112184286, -0.056506597, -0.06697364, -0.12680033, -0.10132648]), layer_4 = (weight = Float32[-5.6652076f-5 -9.081678f-6 -0.00027135498 -0.000113193084 7.212043f-5 1.0767789f-5 -9.260971f-6 -0.00021335353 4.6696918f-5 -1.557303f-5 -1.4991401f-5 -0.00014934348 0.0001325646 5.2522977f-5 6.279714f-6 0.00023555232 -1.4920981f-5 -7.230766f-5 -3.4547393f-5 2.8906923f-5 2.2790093f-5 -3.4172224f-6 -0.000118806405 -3.52838f-5 -2.2917935f-5 -0.0002806627 5.797269f-5 -0.000105062216 -3.2559477f-5 -5.064832f-5 0.00010734205 -2.1510103f-5; -0.0001065396 1.3799787f-5 -9.49658f-7 -4.2224812f-5 0.0002788405 -0.00018578615 -0.00012964394 -3.9963037f-5 0.000117692085 1.5015337f-6 3.1457792f-5 -1.3392169f-6 9.4647025f-5 -8.079548f-5 2.057767f-5 -2.556023f-5 7.133532f-5 -8.439937f-6 0.00015499213 
0.00013563226 8.3124854f-7 -2.714979f-5 9.261711f-5 -0.00015678238 1.8240835f-5 7.655584f-6 0.00021917284 8.2820065f-5 3.6963746f-5 -0.00018938481 0.00016216951 9.094203f-6], bias = Float32[-0.11257731, 0.10439747])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most DL frameworks, Lux defaults to Float32; however, in this case we need Float64.
# Promote the (Float32) Lux parameters to a Float64 ComponentArray.
const params = ComponentArray{Float64}(ps)
# Bundle the network with its state so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
# StatefulLuxLayer{true}(
# Chain(
# layer_1 = WrappedFunction(Base.Fix1{typeof(broadcast), typeof(cos)}(broadcast, cos)),
# layer_2 = Dense(1 => 32, cos), # 64 parameters
# layer_3 = Dense(32 => 32, cos), # 1_056 parameters
# layer_4 = Dense(32 => 2), # 66 parameters
# ),
# ) # Total: 1_186 parameters,
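# plus 0 states.
Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics, corrected by the neural network, using
$$u[1] = \chi, \qquad u[2] = \phi,$$
$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$
where $p$, $M$, and $e$ are constants and $\mathcal{N}_1$, $\mathcal{N}_2$ are the two outputs of the neural network.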
"""
    ODE_model(u, nn_params, t)

Right-hand side `[χ̇, ϕ̇]` of the neural ODE for the state `u = [χ, ϕ]`.

The Newtonian rate `(1 + e cos χ)² / (M p^(3/2))` uses `(p, M, e)` from the global
`ode_model_params`. It is multiplied componentwise by `1 .+ nn_model([χ], nn_params)`.
`t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params
    # In this example `st` is an empty NamedTuple, so we can safely ignore it. In general,
    # `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([first(u)], nn_params)
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
# ODE_model (generic function with 1 method)
# Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
# Same setup as the true model, but the dynamics come from the (untrained) network.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
    # Reference data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        strokewidth=2, alpha=0.5, marker=:circle, markersize=12)
    # Prediction of the untrained neural ODE.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        strokewidth=2, alpha=0.5, marker=:circle, markersize=12)
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
    fig
end
# Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
"""
    loss(θ)

Solve the neural ODE with parameters `θ` and return `(mse, pred_waveform)`.

`mse` is the mean squared error between the predicted plus-polarization waveform and
the global `waveform` data. The second return value is what `callback` receives as
`pred_waveform`.
"""
function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    # Lux losses take (prediction, target); MSE is symmetric, so the value is unchanged.
    return mseloss(pred_waveform, waveform), pred_waveform
end
# loss (generic function with 1 method)
# Warm up the loss function
loss(params)(0.002657665015682404, [-0.026139947243405532, -0.02523518305498842, -0.024330418866571207, -0.022809917595817297, -0.0206548859159035, -0.017839104587750735, -0.014330445296866036, -0.010090402830751074, -0.005079230676427359, 0.0007404627071390518, 0.007391175204723896, 0.014860112985387667, 0.023062589494840705, 0.0317638457625263, 0.04044473304307155, 0.04803815183605482, 0.05245384121385258, 0.0497785206950642, 0.03331124389054785, -0.00588169680292122, -0.06977572449701834, -0.12393046023609698, -0.0857789980960219, 0.04685781419786133, 0.12458387445106871, 0.09204536299541258, 0.02601269552736332, -0.02151668486722114, -0.044741352501309686, -0.05164477109295164, -0.049624687702169976, -0.04329424336173226, -0.03520858766034767, -0.02671120176170369, -0.018482380756592903, -0.010850079408977382, -0.003954909585937038, 0.002158182787816188, 0.007492297847035946, 0.012072254690431104, 0.015931317726701522, 0.01910323527716701, 0.02161674612437164, 0.023498835672629836, 0.024765962250282977, 0.02542963698563249, 0.025493589131937888, 0.024954049822536312, 0.023799600361862724, 0.022011042351261455, 0.01956147802701939, 0.016416362850121043, 0.01253463267221654, 0.0078702915814919, 0.0023767594582620503, -0.00398455788336236, -0.011228977543312839, -0.019315721944120658, -0.028090351437306326, -0.03716898061973708, -0.045721988627640965, -0.052066796912763204, -0.052948004538555186, -0.04249009469732382, -0.011763827528661008, 0.046063438291663095, 0.1124191358081561, 0.11318481045970721, -0.0011900302081099701, -0.112153795768761, -0.1102263279616042, -0.047855617216699516, 0.007059497278462934, 0.0376230610661931, 0.0494544883943377, 0.050342319867101575, 0.04561756770900743, 0.03835186149299362, 0.030206446469946125, 0.022054804085371336, 0.014328524089111906, 0.007242056973463494, 0.0008704250706114314, -0.004756930826164251, -0.009654622149292282, -0.013849659433532136, -0.017362731670408034, -0.020228553850562797, -0.02246303684117339, 
-0.02408884499907867, -0.02511245234434812, -0.02554003206984042, -0.025366867216430625, -0.02458086492313771, -0.023165003461790526, -0.021087887998469473, -0.01831922552506407, -0.014808704094927218, -0.010513883258476445, -0.005372004668236143, 0.0006595923474436698, 0.0076272208706548005, 0.015519305399532847, 0.02425522249826769, 0.033557502897285756, 0.04280483091867297, 0.050662585919604854, 0.05447106916344226, 0.04921564851246787, 0.02653992958113439, -0.022944165766887147, -0.09319094480990875, -0.12670868212602932, -0.04491496834182019, 0.08617322953310161, 0.12127464101953365, 0.07013624148004256, 0.009940611915168535, -0.02825068253565745, -0.04568983801618356, -0.0500428745796685, -0.04730644206851969, -0.04111386254403444, -0.03349095855425922, -0.025530132616659252, -0.017792986815343644, -0.010570971694217488, -0.0039864775429672775, 0.0019022318201345877, 0.007097291824441659, 0.011602508651174782, 0.015448186112751236, 0.01864855760107953, 0.02123132672742958, 0.02320585933688209, 0.024586823200859538, 0.025374689612534515, 0.025566588660651566, 0.025151517662992603, 0.024109502090987323, 0.022411418294729257, 0.020027266194352914, 0.01690263314763123, 0.012989866569397332, 0.008224995154436423, 0.002556282223314735, -0.004086378109133462, -0.011721519477556532, -0.020317771529773456, -0.029694060900486572, -0.03940256278415151, -0.04840755711072974, -0.05457164077343472, -0.053710520657978156, -0.03835224635510637, 0.0018207767061164527, 0.06953748662187974, 0.12649009213312093, 0.08497905323613279, -0.04853150783719464, -0.12219669677859062, -0.09101047226112899, -0.029104634106277914, 0.016508892989252892, 0.04016596118374542, 0.04856877276007282, 0.04824372890161203, 0.04340779519697234, 0.03650104827493318, 0.028859648641151283, 0.021204846345245724, 0.013908907610105738, 0.007166827727450787, 0.001048038036668517, -0.0044103744886052855, -0.009212521263961293, -0.013362533864266836, -0.01689144624463532, -0.019806949851843, 
-0.02212608254270863, -0.023853704995913882, -0.024999145948942408, -0.025553828145059587, -0.025507006008923864, -0.02484211813216025, -0.023528147742492702, -0.021530119682262334, -0.018799369985113928, -0.015288363196071101, -0.010921025285449828, -0.005634834972958377, 0.0006432645744131962, 0.007956899694388389, 0.016330173679059944, 0.025656912034500642, 0.035623548326132085, 0.04545186059574691, 0.053465426549186096, 0.05622919525875514, 0.047289502638555136, 0.016480245757015146, -0.04436818208374012, -0.11465712729078592, -0.11380906328281767, 0.0035256747039827813, 0.11073491922660667, 0.10807839233752191, 0.04970282753406427, -0.0023824110590883216, -0.03271061868761866, -0.045761811005785864, -0.04830421641419382, -0.04514327644640337, -0.03916396408278759, -0.03199131386123907, -0.02451825658784868, -0.017233249919816677, -0.010375390048422447, -0.004074733405741349, 0.001619455561264516, 0.006681638249271333, 0.011129501643549478, 0.01496403616926888, 0.018202235085739437, 0.020849158434504245, 0.02292195420647684, 0.024415968846473114, 0.02532871524840243, 0.025649064292203448, 0.025358595011080853, 0.024429029401157633, 0.022822250367527917, 0.020497905032998816, 0.017389944681169917, 0.013436603920908612, 0.00855831588196852, 0.0026851460124512564, -0.004271471094087118, -0.012345609223740703, -0.021510920603009267, -0.03155837597536026, -0.04193965197276423, -0.05134944215185174, -0.05703529922222981, -0.05357163807155969, -0.03156122711593537, 0.019938330722203627, 0.0944638367546043, 0.12854355570889425, 0.16262327466318413])Now let us define a callback function to store the loss over time
# Loss value recorded at every optimizer iteration.
const losses = Float64[]
"""
    callback(θ, l, pred_waveform)

Append the current loss `l` to `losses` and print a progress line. `θ` and
`pred_waveform` are accepted but unused. Always returns `false`; by Optimization.jl's
callback convention, this lets the optimization continue.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf "Training %10s Iteration: %5d %10s Loss: %.10f\n" "" iteration "" l
    return false
end
# callback (generic function with 1 method)
# Training the Neural Network
Training uses the BFGS optimizer. This appears to give good results because the Newtonian model already provides a very good initial guess.
# Gradients of the loss are obtained through Zygote reverse-mode AD.
adtype = Optimization.AutoZygote()

# The objective depends only on the optimization variable `x`; the second
# (problem-parameter) argument is ignored.
optf = Optimization.OptimizationFunction((x, _) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, capped at
# 1000 iterations; `callback` is invoked once per iteration.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback=callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [7.131151505745947e-5; -0.00013078367919661105; -0.00022395519772544503; 8.934998186305165e-5; -0.00020378638873808086; 3.059842492803e-5; -4.698803604696877e-5; -3.4209129808004946e-5; -1.3376779861573596e-6; -1.1305796761007514e-5; 3.827170075965114e-5; 0.00010751044464996085; -4.525176700553857e-5; -2.0133122234256007e-5; -0.0002162369346478954; 2.0941157345077954e-5; 8.669662929605693e-5; 8.986749162431806e-5; -7.126640412025154e-5; 0.00010068941628560424; 1.904232703964226e-5; 0.00017213565297424793; -2.961255631817039e-5; -7.327409548452124e-5; -0.00016167385911103338; -0.00015032192459329963; -1.0876728083530907e-5; 2.985199898830615e-6; -0.0001376550062559545; 8.667418296681717e-5; 0.00011985660967184231; 3.706417919602245e-5;;], bias = [-0.39356839656829834, -0.7188471555709839, 0.11475205421447754, 0.8971724510192871, 0.8154041767120361, 0.6990301609039307, -0.23833692073822021, 0.33024537563323975, 0.6741292476654053, 0.7641496658325195, 0.08590924739837646, -0.27499890327453613, -0.3189878463745117, 0.8869849443435669, -0.8884336948394775, -0.0671079158782959, 0.4419139623641968, -0.9261888265609741, -0.09410691261291504, -0.6152958869934082, 0.3555927276611328, 0.13374674320220947, -0.9508906602859497, 0.8691732883453369, 0.29971325397491455, -0.2513458728790283, -0.5746331214904785, 0.9365338087081909, 0.2802814245223999, 0.90126633644104, -0.0513838529586792, 0.010474324226379395]), layer_3 = (weight = [5.156359839020297e-5 0.00019067975517828017 3.8886060792719945e-5 -1.7444561308366247e-5 0.00012710984447039664 1.339963887403428e-6 -8.401183731621131e-5 -5.232881085248664e-5 0.00012514341506175697 -0.00010558184294495732 0.00017082331760320812 5.2711631724378094e-5 8.062669076025486e-5 4.7760866436874494e-5 -0.00011966714373556897 -7.860972255002707e-5 6.46160333417356e-5 -7.092315354384482e-5 -2.3540869733551517e-5 4.977906064596027e-5 9.817791578825563e-5 0.00014749125693924725 
7.277513986991835e-7 -0.00017526576993986964 8.114702723105438e-6 -4.2027597373817116e-5 3.159555490128696e-5 7.090783037710935e-5 -0.00018504491890780628 -0.00016893695283215493 -1.5952045941958204e-5 0.00014143857697490603; -2.275540737173287e-6 6.09611815889366e-5 -1.4450851267611142e-5 9.357666385767516e-6 2.193824457208393e-6 2.8449256205931306e-5 -0.00013796848361380398 6.290333840297535e-5 5.1286740927025676e-5 4.455972521100193e-5 0.00010166926949750632 8.726161468075588e-5 -0.00013558505452238023 0.00016659846005495638 8.248985068348702e-7 -1.0119376383954659e-5 -0.0002259498869534582 2.1480776922544464e-5 -0.00012437268742360175 -0.00019795272964984179 5.4754353186581284e-5 8.063705899985507e-5 -9.091036190511659e-5 0.00017693883273750544 -6.437947740778327e-5 0.00010441896301927045 1.4789691704208963e-5 -4.474387606023811e-5 -9.914927795762196e-5 3.8010279240552336e-5 3.832571383100003e-5 -3.043590550078079e-5; -8.543570584151894e-5 2.6426152544445358e-5 -0.00010575642227195203 -2.3693242837907746e-5 -0.00018920551519840956 -6.982802733546123e-5 3.429539719945751e-5 7.599106902489439e-5 -2.4642678909003735e-5 -4.5045759179629385e-5 -5.918285751249641e-5 5.2414143283385783e-5 0.00013944115198682994 -2.8318241675151512e-5 -4.244231240591034e-5 0.000167592748766765 0.00017747799574863166 -0.000182871037395671 -0.00011570445349207148 -6.38922574580647e-5 0.00013647739251609892 0.00015706494741607457 5.478590537677519e-5 -8.79626750247553e-5 5.3103438403923064e-5 3.2038493372965604e-5 -3.8827645767014474e-5 0.00012415676610544324 3.662657763925381e-5 1.2301276910875458e-5 -0.00011112621723441407 -2.9500992241082713e-5; -7.253411604324356e-5 8.789086132310331e-5 0.00011558875849004835 -5.6391581892967224e-5 -0.00015464989701285958 2.5663561245892197e-5 -0.00011515457299537957 -0.0001396962907165289 4.520812217378989e-5 4.0473030821885914e-5 -0.00012583767238538712 0.0001281375443795696 2.0328703612904064e-5 -3.156050297548063e-5 -1.640689697524067e-5 
-7.862273196224123e-5 0.00013978304923512042 6.497126014437526e-5 2.9592854843940586e-5 3.997813109890558e-5 -5.714295184588991e-5 -3.577535244403407e-5 -7.009574619587511e-5 -7.170435856096447e-5 -2.3521610273746774e-5 -2.9568152513093082e-6 -0.00017784553347155452 -0.00021148323139641434 9.792210039449856e-5 -4.573704063659534e-5 9.450320249015931e-6 1.043225438479567e-5; -8.425130363320932e-5 -9.018260607263073e-5 0.00016580193187110126 -0.00022687231830786914 0.00013576587662100792 -2.323028638784308e-5 4.5834291086066514e-5 0.00010316733096260577 4.271070065442473e-5 5.614725159830414e-5 0.00011300598998786882 -3.716586434165947e-5 -0.00014973875659052283 -1.8655988242244348e-5 -6.741919060004875e-5 -6.399850099114701e-5 8.252401312347502e-5 3.643983291112818e-5 -2.633856092870701e-5 -1.2701409104920458e-5 -0.00010382018081145361 0.00023030483862385154 4.0265099698899576e-8 3.578784799174173e-6 -4.729699867311865e-5 -9.030117507791147e-5 -0.00010830732207978144 4.85725240650936e-6 -2.2438813175540417e-5 0.00019308683113195002 3.132610800093971e-5 -6.138355729490286e-6; 0.00017715721332933754 6.444964674301445e-5 -2.79735577350948e-5 0.00013668292376678437 -0.00019245082512497902 7.854711293475702e-5 -0.00011521901615196839 -7.829027163097635e-5 -6.073626718716696e-5 8.497275121044368e-5 -0.00016530374705325812 -5.708709068130702e-5 8.075969526544213e-5 -0.000102479629276786 4.1885472455760464e-5 -0.00020741918706335127 -0.00021628914691973478 -7.611844466737239e-6 0.00015422630531247705 0.00012890920334029943 -6.859296263428405e-5 -8.116226672427729e-5 -6.760940595995635e-5 -3.0659888579975814e-5 -0.00014312427083496004 -0.00012147762026870623 -8.487812738167122e-5 -0.0002090919588226825 -1.5426139725605026e-5 -4.386228465591557e-5 0.00016672314086463302 -0.00013333756942301989; -1.7027012290782295e-5 -7.838660530978814e-5 -2.7537635105545633e-6 -3.4266071452293545e-5 -0.00010930719872703776 -1.0140402082470246e-5 0.00021774104970972985 -3.442178058321588e-5 
6.533635314553976e-5 -5.2558709285221994e-5 4.414929571794346e-5 -9.469506767345592e-5 -9.378704999107867e-5 -0.0002685293438844383 -4.620100662577897e-5 -1.8231337890028954e-5 -7.643468416063115e-5 5.950895501882769e-5 -5.101390343043022e-5 6.0190617659827694e-5 -0.00033646111842244864 0.00013251538621261716 -0.00011883546540047973 -7.533064490417019e-5 6.57798518659547e-5 0.00016976856568362564 0.00014797730545978993 5.476844307850115e-5 0.00022539716155733913 1.708144736767281e-5 4.705919491243549e-5 -4.238273686496541e-5; 0.00025345009635202587 3.166939859511331e-5 0.00014633496175520122 3.902443859260529e-5 2.5036346414708532e-5 2.5601689230825286e-6 -6.528801168315113e-5 5.236720244283788e-5 0.0002026132569881156 -4.8072888603201136e-5 6.290341843850911e-6 -0.00017363097867928445 -0.00010943003871943802 -5.395151674747467e-6 6.216693145688623e-5 2.3805858290870674e-5 0.0002844385162461549 -6.578316970262676e-5 -5.2412804507184774e-5 -9.748831507749856e-5 -1.344160864391597e-5 -0.00010455514711793512 0.0002267369709443301 8.53522215038538e-5 1.9609544324339367e-5 2.044877646767418e-6 -0.00010194712376687676 -0.00013655076327268034 -0.00013482767099048942 -4.628490569302812e-5 -3.927193029085174e-5 -4.0323357097804546e-5; -0.00010294240200892091 8.638996951049194e-5 0.00016807053179945797 -0.00020801310893148184 0.0001546746352687478 -0.00014960135740693659 -3.733771154657006e-5 -0.00018042865849565715 -8.03367147454992e-5 2.040314211626537e-5 3.1480408324569e-6 -7.403209110634634e-6 -0.00010296026448486373 0.00015524947957601398 0.00016167599824257195 -0.00010011806443799287 9.599006443750113e-5 0.00018563908815849572 7.097602065186948e-5 -2.577121449576225e-5 7.736137922620401e-5 9.409066842636093e-5 -7.980458030942827e-5 -1.3846307410858572e-5 -4.509392238105647e-5 -5.8922691096086055e-5 -7.93353610788472e-5 6.475036207120866e-5 -8.791488653514534e-5 -3.929892773157917e-5 0.0001495394535595551 7.421398913720623e-5; 0.00026504372362978756 -7.69282087276224e-6 
-0.00018104819173458964 1.7321714040008374e-5 3.884220404870575e-6 0.00015996047295629978 -3.724328053067438e-5 -4.120950325159356e-5 0.00016032571147661656 -6.307647709036246e-6 -5.226865687291138e-5 8.155228715622798e-5 -4.194629946141504e-5 -5.225798668107018e-5 4.120900484849699e-5 -7.801693573128432e-5 -3.551164627424441e-5 0.0001470997667638585 -7.31532127247192e-5 -0.00012653518933802843 2.2044016077416018e-5 -0.0002652945404406637 -6.87541687511839e-5 0.00027331389719620347 -0.00016666989540681243 1.8141718101105653e-5 -0.00012184365914436057 5.9049409173894674e-5 0.00019336976401973516 -7.755622209515423e-5 6.685210973955691e-5 5.414752195065375e-6; -0.00011702736810548231 -9.039788710651919e-5 7.959983486216515e-5 0.0001097230997402221 6.040323569322936e-5 7.36871370463632e-5 -4.97950823046267e-5 -7.827103399904445e-5 7.882439604145475e-6 -5.489561954163946e-5 -6.127762753749266e-5 5.420421985036228e-6 0.00012381485430523753 0.00020219804719090462 0.00016511863213963807 3.387339893379249e-5 7.349871884798631e-5 -8.028540469240397e-5 0.00019781553419306874 6.586289237020537e-5 5.572840746026486e-5 -0.00010337465209886432 7.301945879589766e-5 4.857040403294377e-5 -0.00019983021775260568 -7.208172610262409e-5 -7.724406896159053e-5 7.878362339397427e-6 7.403608469758183e-5 -1.7159665731014684e-5 -0.00015000678831711411 4.670558701036498e-5; 1.2934280675835907e-5 -6.456909886765061e-6 -2.4318185751326382e-5 -0.00010313167877029628 2.719049916777294e-5 0.00012322996917646378 3.379095869604498e-5 6.575505540240556e-5 -8.253914711531252e-5 -0.00020744190260302275 -2.5740866476553492e-5 -9.228922135662287e-5 0.0001649452606216073 8.3336271927692e-5 6.481775199063122e-5 -9.077947470359504e-5 0.00016111430886667222 -0.00015401570999529213 -5.092590981803369e-6 5.712970596505329e-5 0.00014002480020280927 -1.13437508844072e-5 -4.308610732550733e-5 3.6881931009702384e-5 0.00012544241326395422 5.342317672329955e-5 3.193530938006006e-5 -6.195373862283304e-5 
-0.0001125667622545734 -1.0741084224719089e-5 -0.00015544377674814314 -0.00016407493967562914; 1.7903052139445208e-5 7.599344371556072e-7 -8.874756929344585e-9 -9.888760541798547e-5 1.1979430780684197e-7 3.5137538361595944e-5 -0.00012919632717967033 2.574813879618887e-5 -2.958733239211142e-5 1.8644905139808543e-5 -5.426894858828746e-5 0.0001428186078555882 -6.12571238889359e-5 -3.950946847908199e-5 9.879006393020973e-5 -1.3973963177704718e-5 -0.00016492867143824697 -1.6984584362944588e-5 -7.545234257122502e-5 0.00015263255045283586 -0.00013809339725412428 -5.818481440655887e-5 4.025372618343681e-5 -5.273207716527395e-5 -4.274570983398007e-6 -0.00012813112698495388 6.091036993893795e-5 0.0002514176012482494 -6.452832167269662e-5 -3.223114981665276e-5 -0.0001335370761808008 -0.00024113774998113513; 0.00019302856526337564 0.00011559153790585697 9.253919415641576e-5 3.297725925222039e-5 2.854438753274735e-5 3.812321665463969e-5 8.066637383308262e-5 7.3173064265574794e-6 -4.465677557163872e-5 -0.00020067770674359053 1.5995916328392923e-5 -7.412855484290048e-5 5.159095235285349e-5 -6.007593583490234e-6 -0.00013019147445447743 -0.00016384870104957372 -1.4430496548811789e-6 -0.00021935337281320244 -0.00010655698861228302 0.00021270127035677433 0.0001089227880584076 0.00010242690041195601 -4.540502413874492e-5 -3.939270754926838e-5 3.112285048700869e-5 -3.2462121453136206e-5 -0.0001262575970031321 3.88536172977183e-5 0.00015670087304897606 -2.5878309315885417e-5 -8.950286428444088e-5 9.295035852119327e-5; 4.8602334572933614e-5 1.2288691323192324e-5 0.00012927591160405427 -2.3040704036247917e-5 4.547274511423893e-5 2.0581916032824665e-5 -7.785265188431367e-5 -3.456328340689652e-5 -8.642450120532885e-5 -3.4536089970060857e-6 -1.564373087603599e-5 -1.7107720850617625e-5 -8.66454211063683e-5 -2.871940887416713e-5 0.0001002334029180929 2.1782077510579256e-6 -3.707761788973585e-5 8.696902114024851e-6 0.00013740037684328854 -0.00016546921688131988 8.786544640315697e-6 
-1.7026919522322714e-5 3.2561420084675774e-5 -6.720823148498312e-5 0.00018260040087625384 1.9770190192502923e-5 9.498216968495399e-5 0.00020699790911749005 2.2627935322816484e-5 0.00021628029935527593 -0.00019490830891299993 -6.437905540224165e-5; -6.341886182781309e-5 7.524494139943272e-5 -3.1900843168841675e-5 5.2269886509748176e-5 5.954971493338235e-5 4.3701733375201e-5 -9.535151184536517e-5 8.666117355460301e-5 -5.968932327959919e-6 0.00016350859368685633 0.00023269570374395698 9.531038085697219e-5 -4.965401672052394e-7 0.00011375308531569317 -8.599618922744412e-6 0.00021408669999800622 -0.0001900163624668494 -1.3367693100008182e-5 -0.00012722209794446826 6.847290205769241e-5 -0.00016786226478870958 5.029947351431474e-5 0.00013204518472775817 0.00014563680451828986 -4.837989035877399e-5 -7.434260623995215e-5 6.757258961442858e-5 2.9210579668870196e-5 0.0001970800367416814 -7.945372090034653e-6 -0.00015771714970469475 -9.15919299586676e-5; -1.5075296687427908e-5 0.0001535158371552825 -0.00017725472571328282 -6.718356598867103e-5 -0.00017399278294760734 2.58370823758014e-6 0.00015861782594583929 0.00013679488620255142 9.605943341739476e-5 4.318821538618067e-6 0.00012708887516055256 0.00015925760089885443 -2.4016282623051666e-5 -2.5003660994116217e-5 -0.00022581941448152065 1.823992352001369e-5 6.756676157237962e-5 0.00016922189388424158 2.385108200542163e-5 -5.1032810006290674e-5 0.00010266096796840429 -0.00021631190611515194 -0.00012818959658034146 -1.1127934158139396e-5 6.346975169435609e-6 0.000229330500587821 -8.441721547569614e-6 -5.780874516858603e-7 -0.00011231700773350894 -8.032164623728022e-5 -9.255215263692662e-5 9.290668822359294e-5; -0.000136871516588144 -5.965710442978889e-5 7.443503272952512e-5 -5.125577445141971e-6 7.898058538557962e-5 0.0001950430596480146 -2.7452659196569584e-5 0.0001406589144608006 -0.00013338318967726082 5.61087581445463e-5 -5.949925252934918e-5 -3.716514402185567e-5 -3.066491262870841e-5 0.00020231938106007874 
-5.209029768593609e-5 9.541358303977177e-5 4.455624002730474e-5 -4.7847675887169316e-5 -0.00010344690963393077 -0.0001242331345565617 0.00019607276772148907 -8.475691720377654e-5 -3.545081563061103e-5 2.3854140636103693e-6 -0.00019006896764039993 2.9169092158554122e-5 0.00015010751667432487 0.00020149961346760392 6.121338810771704e-5 1.1145612006657757e-5 9.872864029603079e-5 5.0263392040506005e-5; 1.2660382708418183e-5 4.336226265877485e-5 -2.608302020234987e-5 7.266502507263795e-5 8.630564116174355e-5 8.061679545789957e-5 6.847788608865812e-5 -8.499907562509179e-5 -0.00010012563143391162 -0.0001504147658124566 7.534837641287595e-5 -3.045714220206719e-5 -7.3834256909322e-5 -7.829157402738929e-5 0.00020630266226362437 0.00010878698230953887 -4.620384788722731e-5 2.9062503017485142e-5 -1.3840483006788418e-5 -0.00023476140631828457 -0.00010321206355001777 2.9436221666401252e-5 9.178899199469015e-5 -0.00018288232968188822 -0.0001256522082258016 2.6628791601979174e-5 -7.870322588132694e-5 -3.3296299079665914e-5 -2.1203355572652072e-5 -9.568237146595493e-5 0.00013806378410663456 6.071791631256929e-6; -9.467780910199508e-5 5.493052867677761e-6 3.175126266796724e-6 -3.070318780373782e-5 1.0468203072377946e-5 6.947026849957183e-5 -4.231423736200668e-5 0.00018103778711520135 8.548808546038345e-5 0.00016353160026483238 0.00013080153439659625 1.763648833730258e-5 -0.00011123169679194689 -8.080308907665312e-5 -6.342014239635319e-5 5.928953760303557e-5 -3.068616570089944e-5 -0.00014446199929807335 -8.117519428196829e-6 6.107920489739627e-5 -0.00015054238610900939 2.8523321816464886e-5 0.00013776110426988453 -0.00023868631978984922 -0.00012314986088313162 -3.760907202376984e-5 0.0001125059716287069 0.00014531031774822623 -9.960775059880689e-6 0.00018910352082457393 7.195069338195026e-5 -0.00010065511742141098; -0.00010962554370053113 8.291353879030794e-5 -0.00021776327048428357 4.668173914978979e-6 0.00018954645202029496 -9.791765478439629e-5 7.895516318967566e-5 
8.536226232536137e-5 0.00011110329796792939 0.00012978637823835015 -7.130380981834605e-5 -7.316371193155646e-5 9.594659059075639e-5 8.964981134340633e-6 6.684039544779807e-5 -0.00010442598431836814 -5.329816849553026e-5 9.935316484188661e-5 -3.0483559385174885e-5 5.100475391373038e-5 -4.183334385743365e-5 9.982552728615701e-5 -6.407970795407891e-5 -0.00010264213779009879 1.890852399810683e-5 0.00012266771227587014 -4.741976590594277e-5 0.00020690035307779908 -0.00010621598630677909 2.4675746317370795e-5 0.00013285299064591527 -0.00019623646221589297; -8.14312879811041e-5 0.00011926063598366454 -4.7851532144704834e-5 0.0001988400617847219 -0.00012450723443180323 5.464345667860471e-5 4.3390093196649104e-5 7.090102735674009e-5 4.2525520257186145e-5 -2.73481191470637e-6 -3.206872133887373e-5 -0.00017025585111696273 5.923801290919073e-5 -4.0854865801520646e-5 7.689074664085638e-6 -0.0001467990077799186 6.8385306803975254e-6 9.353524546895642e-6 0.00010305149044143036 0.00011149377678520977 -8.45801196192042e-7 -9.8596399766393e-5 7.637111411895603e-5 -6.793784450564999e-6 -5.4873926274012774e-5 8.346189133590087e-5 0.0001837978488765657 0.00011431792518123984 7.142563117668033e-5 -2.81204006569169e-6 -8.54917016113177e-5 0.00014911332982592285; 6.800088885938749e-5 -6.13871161476709e-5 7.415809523081407e-5 -7.51013430999592e-5 0.00010910239507211372 0.00015358082600869238 -3.283990372437984e-5 8.240390161518008e-5 -0.00017121822747867554 -0.0001260863646166399 9.68988115346292e-6 1.5132873159018345e-5 -7.2982365963980556e-6 4.679658741224557e-5 -2.7416606826591305e-5 0.00016622580005787313 -1.5386422091978602e-5 1.1667392755043693e-5 -6.853812374174595e-5 5.38561653229408e-5 -2.52642003033543e-5 2.1459325580508448e-5 -0.00012681745283771306 -2.075959309877362e-5 -0.0001655579690122977 0.0001124712944147177 -0.00020873191533610225 4.731085937237367e-5 -7.806008943589404e-5 1.6923895600484684e-5 -8.666857320349663e-5 -6.354582728818059e-5; 1.605590477993246e-5 
-9.214781312039122e-5 3.930803723051213e-5 3.339282193337567e-5 -2.2064250515541062e-5 -0.0001116427083616145 -0.0001105472692870535 -2.9351091143325903e-5 0.00026300951140001416 -2.4370649043703452e-5 -6.756846414646134e-5 -1.3051280802756082e-5 1.2546349353215192e-5 -0.00013098583440296352 0.0002212569088442251 -0.0002456110087223351 -0.00013754221436101943 -6.024431968398858e-6 0.00010165593994315714 -0.00010347491479478776 8.12050566310063e-5 -0.00011150022328365594 -5.101117039885139e-6 5.1772362894553225e-6 -4.5687320380238816e-5 5.250840331427753e-5 7.424168416036991e-6 5.5018543207552284e-5 -6.672957533737645e-5 1.6823934856802225e-5 -5.916389227422769e-7 0.00011691221880028024; 0.00016468639660160989 -4.142229045100976e-6 1.6305993995047174e-5 4.338502094469732e-6 5.628872531815432e-5 -5.3282856242731214e-5 9.931611566571519e-5 7.054002344375476e-5 0.00014977036335039884 9.622370271245018e-5 -0.00013130984734743834 -8.779023482929915e-5 4.96920220030006e-5 -0.00022217306832317263 1.8167891539633274e-5 0.0001069377685780637 -0.00012857520778197795 -2.3393142328131944e-5 9.306996071245521e-5 -6.0381167713785544e-5 -8.323797374032438e-5 4.5449927711160854e-5 -1.69510403793538e-5 -2.958697223220952e-5 0.0002056476369034499 -4.5735618186881766e-5 3.083734554820694e-6 -4.606014408636838e-5 -4.0424631151836365e-5 -0.00010462451609782875 2.6280902602593414e-5 -2.8668267987086438e-5; 1.9805096599156968e-5 1.4936950719857123e-5 -1.691809120529797e-5 -5.885760401724838e-5 -9.012574446387589e-5 -0.0001729521172819659 5.3509065764956176e-5 7.940650539239869e-5 -7.841360638849437e-5 -3.458663195488043e-5 -0.00023340428015217185 -0.00010877197928493842 -0.00011123692092951387 0.00016576774942222983 0.00015244641690514982 0.00023112882627174258 -0.00013181641406845301 -3.652520490504685e-6 0.0001981431123567745 -0.0002039439423242584 0.00018931507656816393 7.898527837824076e-5 -6.521945033455268e-5 7.485369133064523e-5 -0.0001417256862623617 -6.618673069169745e-5 
2.276073064422235e-5 0.0002242559421574697 -0.00022733074729330838 5.999006316415034e-5 -4.476529284147546e-5 0.000154901368659921; -0.00018362187256570905 2.8344835300231352e-5 5.875679198652506e-5 -9.042311830853578e-6 0.0001583642151672393 -0.00012203656660858542 -0.00013775906700175256 -8.766842074692249e-5 7.999336230568588e-5 -0.00015464850002899766 -6.200328243721742e-6 5.8369590988149866e-5 7.971576997078955e-5 -4.6827553887851536e-5 7.362258475041017e-5 -0.00012464872270356864 0.00010568612196948379 -2.3056836653267965e-6 2.0719720851047896e-5 0.00012883238377980888 1.4306578123068903e-5 -0.00010382722393842414 -0.00011536503006936982 -7.477475446648896e-5 -2.8425611162674613e-6 -4.963789615430869e-5 -2.0438847059267573e-7 -5.8167861425317824e-5 1.1013881703547668e-5 0.000135424270411022 -8.206035272451118e-5 -0.00010210004984401166; -0.00011804371752077714 4.879320476902649e-5 0.00021875354286748916 0.0002522997965570539 -4.4419535697670653e-5 3.124824070255272e-5 -9.218024933943525e-5 -1.9714538211701438e-5 -5.2572791901184246e-5 4.449607149581425e-5 3.879932046402246e-5 4.951702430844307e-5 -1.3214127648097929e-5 -0.00012538852752186358 9.52649861574173e-6 1.100346435123356e-5 0.00013922595826443285 0.00013214659702498466 2.3089327783054614e-7 0.00012948241783306003 -8.039539534365758e-5 0.00018970041128341109 1.4562219803337939e-5 0.00011223969340790063 5.37329469807446e-5 0.00027434396906755865 -7.338939030887559e-5 0.00011717034794855863 1.5658158645237563e-6 0.00018022941367235035 4.733516834676266e-5 -5.2879902796121314e-5; -6.222676893230528e-5 0.0002775734756141901 3.443772584432736e-6 -1.4459124031418469e-5 -4.213319243717706e-6 9.586807209416293e-6 5.862694524694234e-5 6.916483835084364e-5 1.932861232489813e-5 4.818578963750042e-5 0.00013721796858590096 -0.0001890058338176459 0.00010223389108432457 -4.083046587766148e-5 2.2576074115931988e-5 -0.00013741024304181337 -6.52634771540761e-5 -5.948753823759034e-5 9.239703649654984e-5 
0.00011669005471048877 -4.361863830126822e-5 5.606485137832351e-5 3.915324487024918e-5 0.00013993436004966497 -6.717593350913376e-5 -5.531340138986707e-5 6.657954509137198e-5 -5.665272692567669e-5 7.904965968918987e-6 -0.00018478406127542257 0.0001756436686264351 -1.0333412319596391e-5; -0.00010444985673530027 -0.00010236347588943318 6.608462717849761e-5 0.00011314458242850378 0.00013857914018444717 -0.00011828510469058529 4.603849083650857e-5 -4.089863068656996e-5 -8.135885582305491e-5 -1.874193549156189e-5 -5.390911246649921e-5 -8.040257739594381e-7 -4.9120695621240884e-5 7.67628152971156e-5 -0.0001234796509379521 -7.100364746293053e-5 8.747048559598625e-5 -2.902499181800522e-5 -0.00010388233931735158 0.0002226145297754556 8.514171349816024e-5 0.00019381266611162573 8.092069037957117e-5 4.39742834714707e-5 2.3634469471289776e-5 -3.086109791183844e-5 2.1804104108014144e-5 -1.1330622328387108e-5 -2.9188047847128473e-5 8.988073386717588e-5 5.233488991507329e-6 4.098313002032228e-5; -4.989862645743415e-5 -6.455292896134779e-5 2.42199057538528e-5 -0.00010985318658640608 -6.856126856291667e-5 -0.00010198281961493194 -7.982736133271828e-5 -5.4867396102054045e-5 6.779420073144138e-5 -0.00017294238205067813 -8.94746626727283e-5 4.323278335505165e-5 -8.979545782494824e-6 0.00020485180721152574 -0.0002699383476283401 3.590197957237251e-5 -7.731442019576207e-5 -5.620169758913107e-6 -3.9574155380250886e-5 -9.64802602538839e-5 0.00014572379586752504 -4.025620728498325e-5 -9.298345685238019e-5 2.925133412645664e-5 -8.965071174316108e-5 -0.0001528637803858146 -0.0001119524022215046 0.00010093515447806567 8.729911314730998e-6 9.716435306472704e-5 7.367786747636274e-5 -7.197131344582886e-5; 0.00013874641444999725 7.027042738627642e-5 -4.926845576846972e-5 5.407356366049498e-5 5.181305095902644e-5 -8.973737567430362e-5 -9.77306262939237e-5 -1.0230601219518576e-5 8.33973717817571e-6 -2.2131924197310582e-5 7.285369065357372e-5 1.0869996003748383e-5 -1.2246420737938024e-5 
-5.906746446271427e-5 0.0001458988117519766 0.000178064830834046 -9.759460226632655e-6 -2.5184879632433876e-5 5.834676994709298e-5 0.00016339229478035122 -4.9666974518913776e-5 -5.386127304518595e-5 0.00011604426254052669 -0.00016577377391513437 -0.0001484101958340034 -3.726542490767315e-5 -3.0266250178101473e-5 9.583700011717156e-5 3.117765299975872e-5 2.505421252863016e-5 7.085471588652581e-5 -0.00010311305231880397], bias = [0.14560095965862274, 0.14954809844493866, -0.0068453224375844, 0.13693609833717346, -0.12243223190307617, 0.06423419713973999, -0.09587834775447845, 0.13855601847171783, -0.14416959881782532, -0.1295659840106964, 0.017134612426161766, 0.032455600798130035, 0.16777649521827698, -0.0061350008472800255, 0.012148765847086906, 0.14453285932540894, -0.16991734504699707, 0.05646685138344765, 0.11063890904188156, -0.0802336037158966, 0.07706256210803986, 0.005536093842238188, 0.012314255349338055, 0.021647317335009575, 0.01715477928519249, 0.015032263472676277, 0.14215363562107086, -0.11218428611755371, -0.05650659650564194, -0.06697364151477814, -0.1268003284931183, -0.10132648050785065]), layer_4 = (weight = [-5.665207572747022e-5 -9.081678399525117e-6 -0.00027135497657582164 -0.00011319308396195993 7.212042692117393e-5 1.0767788808152545e-5 -9.260970728064422e-6 -0.00021335353085305542 4.669691770686768e-5 -1.5573030395898968e-5 -1.4991401258157566e-5 -0.0001493434829171747 0.0001325646007899195 5.252297705737874e-5 6.279713943513343e-6 0.00023555231746286154 -1.4920980902388692e-5 -7.230765913845971e-5 -3.454739271546714e-5 2.890692303481046e-5 2.2790092771174386e-5 -3.4172223877249053e-6 -0.00011880640522576869 -3.528379966155626e-5 -2.2917934984434396e-5 -0.0002806627016980201 5.797269113827497e-5 -0.00010506221588002518 -3.255947740399279e-5 -5.0648319302126765e-5 0.00010734204988693818 -2.151010266970843e-5; -0.00010653959907358512 1.3799786756862886e-5 -9.496579878032207e-7 -4.222481220494956e-5 0.000278840510873124 -0.0001857861498137936 
-0.0001296439440920949 -3.9963037124834955e-5 0.00011769208504119888 1.5015336884971475e-6 3.145779191982001e-5 -1.3392168511927593e-6 9.464702452532947e-5 -8.079547842498869e-5 2.057767051155679e-5 -2.5560229914844967e-5 7.133532199077308e-5 -8.439937118964735e-6 0.00015499212895520031 0.00013563226093538105 8.312485419992299e-7 -2.7149790184921585e-5 9.261711238650605e-5 -0.00015678237832617015 1.8240834833704866e-5 7.655584340682253e-6 0.00021917284175287932 8.282006456283852e-5 3.696374551509507e-5 -0.00018938480934593827 0.00016216951189562678 9.094203051063232e-6], bias = [-0.11257731169462204, 0.10439746826887131]))Visualizing the Results
Let us now plot the loss over time.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")
    # A thick line shows the trend; a marker is drawn at every recorded iteration.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses;
        strokewidth=2, markersize=12, marker=:circle)
    fig
end
Finally let us visualize the results
# Re-integrate the model with the optimized parameters `res.u`, using fixed-step
# RK4 (step `dt`, no adaptivity) and saving at `tsteps`, then turn the resulting
# trajectory into a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4();
    u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Series are added in a fixed order (data, untrained, trained) so that
    # each one keeps the same cycled colour.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    # Each legend entry combines a line with its markers; placed bottom-left.
    axislegend(ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU details are printed only when MLDataDevices is loaded and the matching
# backend package is both loaded and reported functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA)
        if MLDataDevices.functional(CUDADevice)
            println()
            CUDA.versioninfo()
        end
    end
    if @isdefined(AMDGPU)
        if MLDataDevices.functional(AMDGPUDevice)
            println()
            AMDGPU.versioninfo()
        end
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.