Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, AMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

# Turn scalar (element-by-element) indexing of CUDA arrays into an error instead of a
# silent, very slow fallback.
# NOTE(review): nothing visible in this tutorial moves data to the GPU; this line only
# guards against accidental GPU scalar indexing.
CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it is not particularly relevant.

We need a very crude two-body path. Assume the one-body motion is the Newtonian two-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to recover $r_1 = \frac{m_2}{m_1 + m_2} r$ and $r_2 = -\frac{m_1}{m_1 + m_2} r$ (e.g. *Theoretical Mechanics of Particles and Continua*, §4.3).

julia
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of the two
bodies about their centre of mass.

Returns `(r₁, r₂)` with `r₁ = m₂/(m₁ + m₂) .* path` and `r₂ = -m₁/(m₁ + m₂) .* path`,
so that `r₁ - r₂ == path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e \cos\chi)$, $x = r \cos\phi$, and $y = r \sin\phi$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Change variables from the orbital solution (χ(t), ϕ(t)) to Cartesian (x(t), y(t)).
#
# `soln` has one column per time sample and either
#   * 2 rows `[χ; ϕ]`, in which case `model_params = (p, M, e)` must be supplied and
#     the constant `p` and `e` are used, or
#   * 4 rows `[χ; ϕ; p; e]`, in which case `p` and `e` vary per sample.
# The radius is r = p / (1 + e cos χ), and x = r cos ϕ, y = r sin ϕ.
# Returns a 2×N matrix whose rows are x and y.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    # `in` is written out in ASCII; the original `∈` was lost, which broke parsing.
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` explicitly so a missing argument fails with this
        # message instead of a MethodError from `length(nothing)`.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params   # M is not needed for the geometry
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function computes the first time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

First time derivative of the uniformly sampled series `v` (sample spacing `dt`).

Interior points use central differences; the two endpoints use second-order
one-sided stencils. `v` must have at least 3 elements.
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:n] .- v[1:(n - 2)]) / 2
    right_edge = 3 / 2 * v[n] - 2 * v[n - 1] + 1 / 2 * v[n - 2]
    return vcat(left_edge, interior, right_edge) / dt
end
d_dt (generic function with 1 method)

This function computes the second time derivative using the central three-point stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Second time derivative of the uniformly sampled series `v` (sample spacing `dt`).

Interior points use the three-point central stencil; the two endpoints use
second-order one-sided four-point stencils. `v` must have at least 4 elements.
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    right_edge = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define functions to compute the trace-free quadrupole moment tensor from the orbit and, from it, the (2,2)-mode gravitational-wave strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

One component of the trace-free mass quadrupole tensor of a planar orbit.

`orbit[1, :]` and `orbit[2, :]` are the x and y samples. `component` is an index
pair:
* `(1, 1)` → `mass * (x² - (x² + y²)/3)`
* `(2, 2)` → `mass * (y² - (x² + y²)/3)`
* any other pair → `mass * x * y`, the off-diagonal term.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    trace_part = (x .^ 2 .+ y .^ 2) ./ 3

    is_xx = component[1] == 1 && component[2] == 1
    is_yy = component[1] == 2 && component[2] == 2

    moment = if is_xx
        x .^ 2 .- trace_part
    elseif is_yy
        y .^ 2 .- trace_part
    else
        x .* y
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative of one trace-free quadrupole component of `orbit`.
`orbit` is sampled every `dt`; the component is selected by `component`, see
`orbit2tensor`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

The three independent in-plane quadrupole terms for `orbit`, returned in the
order `(h11, h12, h22)`. Each is computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_of(index_pair) = h_22_quadrupole_components(dt, orbit, index_pair, mass)
    h11, h22, h12 = map(component_of, ((1, 1), (2, 2), (1, 2)))
    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

(2,2)-mode strain `(h₊, hₓ)` of a single unit-mass body following `orbit`,
sampled every `dt`.

`h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12`. The scaling constant is
built in the precision `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2)-mode normalisation is the square root √(π/5), not π/5. `sqrt` is
    # written out in ASCII because the `√` symbol had been lost here.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Combined quadrupole terms `(h11, h12, h22)` of two bodies. Each body's
contribution comes from `h_22_quadrupole` with its own orbit and mass, and the
two contributions are summed term by term.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    return map(+, first_body, second_body)
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

(2,2)-mode strain `(h₊, hₓ)` of two bodies with masses `mass1` and `mass2`
following `orbit1` and `orbit2`, sampled every `dt`. The masses must sum to 1
(checked to within 1e-12).

`h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12`, using the summed two-body
quadrupole terms.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2)-mode normalisation is the square root √(π/5), not π/5. `sqrt` is
    # written out in ASCII because the `√` symbol had been lost here.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Convert an orbital ODE solution `soln` into the (2,2)-mode strain `(h₊, hₓ)`,
sampled every `dt`.

`mass_ratio` q must lie in [0, 1]:
* q = 0 — the orbit is treated as a test particle and the one-body strain is
  returned.
* q > 0 — the masses are normalised to `m₂ = 1/(1 + q)` and `m₁ = q m₂`, so they
  sum to 1. The relative orbit is split into the two body orbits and the two-body
  strain is returned.

`model_params` is forwarded to `soln2orbit`; it is required when `soln` has 2 rows.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparison operators: the original `≤`/`≥` had been lost, leaving the
    # undefined names `mass_ratio1` / `mass_ratio0`.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$ (semi-latus rectum), $M$ (mass), and $e$ (eccentricity) are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for a point-like particle in a
Schwarzschild background.

The state is `u = [χ, ϕ]` and the parameters are `(p, M, e)`. Returns `[χ̇, ϕ̇]`.
`t` is unused because the system is autonomous.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)

    radial = 1 + e * cosχ
    numer = (p - 2 - 2 * e * cosχ) * radial^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ_rate = numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * denom)
    ϕ_rate = numer / (M * (p^(3 / 2)) * denom)
    return [χ_rate, ϕ_rate]
end

# Problem setup. Every one of these globals is read inside `loss` (or the ODE models),
# so they are declared `const`: non-const globals are untyped, which makes each access
# type-unstable. This follows the tutorial's own advice about globals.
const mass_ratio = 0.0         # test particle (mass ratio q = 0)
const u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
const datasize = 250           # number of waveform samples
const tspan = (0.0f0, 6.0f4)   # time span of the GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each sample
const dt_data = tsteps[2] - tsteps[1]  # sample spacing used for finite differences
const dt = 100.0                       # fixed RK4 integration step
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl.

julia
# Integrate the true relativistic orbit with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns (h₊, hₓ); only h₊ is kept as the reference waveform.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line plus sample markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Group line and markers under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the orbital angle $\chi = u[1]$) and has two outputs. We then write a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the time derivatives of the state.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and instead initialize the weights with `truncated_normal` from WeightInitializers.jl:

julia
# Network: apply cos elementwise to the input, then two 32-unit cos-activated Dense
# layers, then a linear Dense layer with 2 outputs. All weights are drawn from a
# truncated normal with std 1e-4, so the untrained network's output is close to zero.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# NOTE(review): `Xoshiro()` is unseeded, so the initial parameters (and therefore the
# training run) differ between sessions; pass a seed, e.g. `Xoshiro(0)`, for reproducibility.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-8.867039f-5; 7.416554f-5; -8.3523424f-5; 1.341447f-5; -0.00018793318; -5.3853f-5; 0.00017667204; -3.99257f-5; -3.1561274f-5; 0.000118529904; 8.573634f-5; 8.974296f-5; 4.237295f-5; 5.0889263f-5; 6.988628f-5; -0.00011852392; -5.109255f-6; 6.824975f-5; -8.960609f-5; -0.00011109935; 4.0976673f-5; -8.3042905f-5; -8.961094f-6; -2.3573835f-5; -0.00013556019; 6.406412f-5; -0.00016370116; 9.30153f-6; -0.00012561903; -7.946415f-5; -8.3233404f-5; 0.00015809275;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[-2.4663226f-5 5.46967f-5 0.0001463716 -3.4480883f-5 -4.037039f-5 5.405099f-5 7.297681f-5 -0.00011844118 0.000104734085 -0.00011177562 0.00020245991 0.00020026896 8.797424f-5 3.589162f-5 6.781178f-5 8.262759f-5 0.00012611314 1.4971126f-5 -0.00022475456 -0.0001824752 4.237918f-5 -0.00010268068 0.00013310954 6.364162f-6 -6.0619288f-5 2.3454644f-5 -0.00011768659 -5.4191365f-5 0.00015187183 -0.00014901615 -3.211651f-5 -0.00018643119; 8.400931f-5 8.190448f-5 -7.5795346f-5 -0.00014490525 -8.026669f-5 0.00013214833 6.3365565f-5 7.2070325f-5 1.4029109f-5 2.1522306f-5 -7.369554f-5 -0.00014524483 -3.151979f-5 -8.233094f-5 -6.0942082f-5 2.3419234f-5 -3.5395453f-5 6.8046857f-6 -5.955497f-5 4.353641f-5 -2.763815f-5 0.0001408046 -3.6216014f-5 -4.52484f-5 -3.9989267f-5 4.9091854f-5 9.5769765f-6 0.00015195814 5.9358943f-5 -9.657485f-5 -3.1359814f-5 1.1793401f-5; -5.2539603f-5 0.00011845507 -0.00024306904 9.697222f-6 8.8296954f-5 0.00010465296 2.8727289f-5 -5.5131528f-5 0.00012723249 -0.000107112406 -7.708885f-5 5.5617467f-5 -1.9370278f-5 -2.4160769f-5 0.000115285 -3.1368716f-5 0.00014661191 -8.9497444f-5 0.00012131467 -0.00013116146 -7.7973884f-5 0.00017541114 1.7877275f-5 -1.7823138f-5 0.00017857323 -2.9119034f-5 0.00017016496 5.5948192f-5 0.00012793126 -5.9447888f-5 
0.00014498932 9.690766f-6; 2.955162f-5 0.00016055409 -0.00014412394 -3.269356f-5 -6.6276494f-5 -3.4136767f-5 2.472136f-5 -4.6417932f-5 -0.00020297938 -3.5066507f-5 1.0180169f-5 -5.085633f-5 0.00015109706 8.5875516f-5 0.00021773236 -3.492714f-5 -0.00013812026 8.3580504f-5 9.752847f-5 -2.2259997f-6 -8.1014725f-5 -8.8121f-5 -5.857862f-5 -2.123647f-5 -3.2052554f-5 -0.00011200857 -0.00015449007 4.754837f-5 -8.56635f-5 -3.266894f-5 0.0001299648 0.000151268; 1.2165002f-5 -1.9906167f-6 -4.478767f-6 1.6926488f-5 0.00016864626 -0.00012150254 1.8440422f-5 -6.134338f-5 0.000120134064 -3.8363352f-5 3.0897514f-5 3.503678f-5 6.470618f-5 -2.832286f-5 -3.4378252f-5 -1.9018078f-5 -6.267512f-5 0.00010506202 -4.0592833f-5 -9.230559f-5 0.00017425604 -0.00023758337 7.999391f-6 -2.1674214f-5 9.792499f-5 -3.700005f-5 -8.174925f-5 -6.200885f-5 0.00021389616 -5.5894016f-5 -0.0001105185 -8.046602f-5; -6.7915054f-5 -0.00014821188 8.422408f-5 -5.8989535f-5 6.697295f-5 -4.2998992f-5 0.00017475973 -0.00036095522 -4.7272846f-5 6.996583f-5 -8.0235644f-5 -0.00011565367 -4.4686083f-5 6.0914543f-5 -1.9955673f-5 6.2838815f-5 -0.000107378655 -3.1186006f-5 -0.00018124704 2.7625203f-5 -0.000110388944 -8.9372064f-5 4.4110577f-5 9.2353686f-5 7.072233f-5 4.2070933f-5 6.544264f-5 -5.849988f-5 -0.0001502367 5.615561f-5 -3.0056055f-5 -7.770206f-6; 1.9214807f-5 -3.9501316f-5 1.8342041f-6 -0.00014592486 0.000105537816 -4.4329023f-5 5.016228f-5 1.3590231f-5 8.224198f-5 0.00013017692 0.00014121264 -0.00015685085 5.003892f-5 -0.00010419413 -0.00013356496 -8.686779f-6 0.00013745339 -1.8051629f-5 -0.00015203668 -0.00019363403 3.4004395f-5 2.3496219f-5 0.00012532654 0.00012756189 -4.6274625f-5 7.178278f-5 6.202943f-5 9.5346803f-7 4.0353694f-5 5.9009264f-5 4.8922175f-5 4.169317f-5; 6.0920232f-5 0.00010437742 0.00015709484 1.13517945f-5 8.456238f-5 -0.00011710144 8.596785f-5 3.535803f-6 -8.165141f-5 -3.942573f-5 -7.320776f-5 -2.447974f-6 1.5222952f-5 3.9299935f-5 -1.7043714f-5 6.51308f-5 -0.00016239702 7.0561575f-5 
-3.7803784f-5 -0.0001811208 -0.00015309146 -0.00022680692 8.786712f-5 5.8372825f-5 2.8744682f-5 2.9284458f-6 0.00014017696 -0.0001198378 4.354081f-5 -0.00010086141 0.00012593632 0.00011232371; -4.659168f-5 0.00020566257 0.00014050507 -1.8498888f-5 -9.801278f-5 -0.00019698056 0.0001382167 1.161983f-5 -4.1388113f-5 -1.9489507f-5 -0.000107232954 4.699142f-5 -0.00012979495 4.284331f-5 -0.00011468392 7.255112f-5 -4.5639114f-5 0.00012155309 7.5814976f-5 3.552371f-5 0.000117173506 9.6490796f-5 0.0001746075 4.11201f-5 0.00021681857 -6.210224f-6 3.3665496f-5 2.1972988f-5 3.0656127f-5 -8.914562f-5 3.4218192f-5 5.926755f-6; 0.00023464004 2.9943514f-5 2.4276895f-5 -0.000101582904 8.745663f-5 0.000117331925 5.6935754f-5 -0.00010349414 0.00014455433 7.147361f-5 -8.741043f-5 6.6630935f-5 2.1540744f-5 -6.268401f-5 -6.249639f-5 -8.668393f-6 5.39531f-5 -1.3646824f-5 0.00014352183 -0.00020002149 -3.2004868f-5 0.00012372916 0.00015520469 4.4085118f-5 -0.00016380136 -0.00012338438 -4.0826202f-5 -0.00019018409 -1.0502019f-5 -8.477703f-5 0.00015526694 -4.462281f-5; -9.947593f-5 1.5063005f-6 0.00014460056 -0.00012083987 0.00014957879 0.00013345339 1.4485001f-5 0.00014783573 4.112854f-5 -0.00010654259 0.00011303878 -0.0001258732 -7.460111f-5 -9.359642f-5 -6.426757f-5 -0.00012133081 8.922034f-5 5.8682097f-5 -2.7119657f-5 -2.8317334f-5 -5.4934867f-6 -2.46757f-5 4.936379f-5 -2.5232795f-5 0.00020656096 -6.964172f-5 8.010764f-5 -4.4639677f-5 -8.328375f-5 -5.2487623f-5 0.00017180633 -1.6355141f-5; 0.000138758 0.00019382889 -5.863604f-5 3.3305714f-5 9.172357f-5 7.284863f-5 0.00014184353 -0.0001686996 -0.00012629996 2.8608667f-5 -4.4426426f-5 -0.00012288899 -7.575053f-5 -2.2027069f-5 4.4801633f-5 -0.000120275436 -3.2849395f-5 8.926543f-5 -8.8487104f-5 -2.735039f-5 4.6418743f-5 -0.00011460029 -9.235032f-5 8.67862f-6 -1.8450117f-5 -0.00012753588 -7.8639954f-5 -5.0471986f-5 0.00013541848 -0.00015862811 -0.00015877957 -2.4194025f-5; -4.8660768f-5 -0.00013391151 -7.349683f-5 8.02818f-5 -0.00015986324 
2.1024289f-5 4.9214403f-5 9.395384f-5 -3.9047495f-6 0.00010680931 4.6959554f-5 -6.885977f-5 1.3254657f-5 3.5777553f-5 1.9381634f-5 -2.6401435f-6 8.38674f-5 -0.00014774437 -0.00016102997 -1.24842945f-5 3.933436f-5 5.5315635f-5 -1.7530168f-5 3.6954436f-5 0.00024348218 -6.913201f-5 2.0428037f-5 4.0340165f-5 -2.4997637f-5 1.478078f-5 -0.00020621238 -0.00012701626; -9.660256f-5 -8.569874f-5 0.00015073585 -5.6945068f-5 -7.20865f-5 -7.954289f-5 -5.493525f-5 -0.00016404643 -9.4787625f-5 4.4568485f-5 0.00011648602 -0.00013451779 -0.00011554227 -0.00010831839 0.00013346571 3.6777525f-5 -6.7826666f-5 -9.958734f-5 -3.5779028f-6 -9.93983f-6 0.00016225304 -0.00013576419 -2.8019038f-5 1.2431868f-5 -6.635235f-6 -6.579719f-5 -0.00010876026 -2.3931681f-5 0.0001493431 8.9609347f-7 -6.459864f-5 2.2051316f-5; -6.142743f-5 -0.00021140835 7.427177f-5 2.222966f-5 -4.1423365f-5 -3.8283953f-5 1.360623f-5 -0.00012412127 3.284031f-5 -0.0001367559 5.5961904f-5 -0.00016105443 3.8492486f-5 6.379136f-5 2.4410394f-6 -2.1775204f-6 -7.8292345f-5 -0.00012222219 -1.4773377f-5 0.00022841519 0.00010553008 -4.5326728f-5 -6.27481f-5 0.00017479938 -0.00010043077 3.934118f-5 -5.4498723f-5 9.453292f-6 -3.9463845f-5 5.4951175f-5 0.00014109425 -5.4340244f-6; 8.825866f-7 9.379605f-6 -0.00010700233 8.2587394f-5 -4.4107004f-5 -4.4118253f-5 5.309249f-5 -6.0871793f-5 -3.1992997f-5 2.928864f-5 -0.00011599106 4.6508743f-5 -0.000109176195 -1.589789f-5 -2.7586017f-5 0.00010901788 9.6754f-5 -0.000207478 -2.3323606f-5 -7.844531f-5 -0.0002241967 8.522391f-6 9.250505f-5 6.040493f-5 0.00011749644 0.00011075076 0.00012912294 0.00017314746 -1.5519647f-5 -4.615071f-5 -0.00012067084 2.5620542f-5; -2.7755304f-5 -6.999636f-5 1.8877801f-5 -1.4189199f-5 8.9184934f-5 3.595105f-5 -0.0001522095 0.00014610187 -6.887941f-5 -4.932378f-5 -7.621989f-5 0.00017176325 -0.00010984087 0.00013584187 9.413452f-5 -5.5273846f-5 2.9186614f-5 -0.00011949047 -6.354327f-5 1.2808068f-5 -6.0990984f-5 -5.3717304f-5 5.3464315f-5 6.9071386f-5 -6.637186f-5 
-6.6472294f-6 -3.066841f-5 3.6103615f-5 -0.00014115294 4.49174f-5 -5.7940706f-5 -9.97f-5; -4.6924306f-5 -3.575801f-5 -0.00012556184 -3.391472f-5 8.185405f-5 -5.1940795f-5 1.2575188f-5 -4.832944f-5 -4.3117572f-5 -9.074991f-5 -3.977878f-5 -8.309073f-5 -1.0714321f-5 0.00012701227 -0.00012915173 -6.1639534f-5 6.1865154f-5 -1.3553347f-5 -3.0433732f-5 -0.00013782103 0.00016800988 -7.07187f-5 2.949288f-8 -0.000120356526 6.5524306f-5 4.2722222f-5 -3.335502f-5 6.839124f-5 0.00020893024 1.9596659f-5 -7.570181f-5 -6.7985f-5; 9.848737f-5 -2.392576f-5 -0.00015034377 -3.9717356f-6 -0.00012080146 2.9074112f-5 -0.00010654375 -0.0001797384 0.0001591045 5.0553594f-5 -0.0001103808 -1.8568275f-5 -3.0598334f-5 -7.147367f-5 5.2602627f-5 7.808862f-5 4.180773f-5 -1.1131241f-5 7.59614f-5 8.8454675f-5 7.263702f-5 0.000119226206 5.0316277f-5 -4.704839f-5 -0.00012459057 0.00014102245 -6.8957896f-5 -0.00011361391 -9.058659f-5 -0.00019780148 -5.6960347f-5 -6.6476576f-5; 5.7772133f-5 -9.078852f-5 3.0791623f-5 -5.9679278f-5 -6.611829f-5 -0.00010625726 -9.632312f-5 -4.909664f-5 -4.2643667f-5 -5.143138f-5 4.1404797f-5 -5.2135685f-5 -0.00013624331 2.0565633f-6 -1.9270847f-5 0.00019906204 -0.00019169401 -9.514682f-5 -0.00011255626 7.439998f-5 0.00020648056 6.0953087f-5 3.1032487f-5 -0.00013159448 8.6210115f-5 -3.23723f-5 -1.0417006f-5 0.00012581065 -0.000186369 -8.1654f-6 0.00015782178 1.8765308f-5; -0.00017145938 0.00014257352 -6.63943f-5 -4.5507488f-5 0.00015409614 2.7156968f-5 -0.00016239316 0.00012629949 -0.00013424219 3.2155323f-5 1.0596837f-5 8.8069995f-5 0.00015023263 0.0001915779 -3.5178255f-5 -0.00010572186 -2.6710937f-5 0.00019697066 0.00021229453 -0.00011167061 -4.8576767f-5 6.146176f-5 2.4059573f-5 6.99716f-5 -0.00023017489 5.6888362f-5 0.00023506118 -6.0088914f-5 -2.6454247f-5 -2.7534914f-5 -7.8794794f-5 6.972333f-5; 0.00013471163 -2.2137554f-5 -6.0058497f-5 4.466842f-5 5.6800647f-5 -1.3217895f-5 -0.00014348046 3.9589348f-5 -0.00022157049 2.51182f-6 -0.00011027313 2.5294004f-5 
7.5238975f-5 -5.6310873f-5 7.683867f-5 7.431382f-5 -8.521806f-5 -8.740854f-5 -1.634518f-5 6.3834435f-5 -4.009585f-5 9.313936f-5 7.184759f-5 -7.749128f-5 0.000107070526 -6.652815f-5 0.0001070305 8.189027f-5 9.693171f-5 -0.000101691214 0.00019646667 -5.796563f-7; 0.00021729007 5.0867224f-5 -9.619255f-5 -9.873551f-7 6.446504f-5 1.303398f-5 1.1135827f-5 0.00019197814 8.93165f-5 -8.2085855f-5 -0.00010658077 -6.536584f-5 -4.011337f-5 -0.00011690591 -5.981415f-6 -0.000108512184 -4.883026f-5 9.854732f-6 -0.00014579992 0.00024444592 -8.616298f-5 2.1912458f-5 1.6573707f-5 5.9382f-5 0.00012043838 -5.006534f-5 0.0002455571 2.7138938f-5 -8.443972f-6 -9.340865f-5 0.000113562615 3.207917f-5; 9.301772f-5 0.00013235677 -0.0001549189 -1.4700352f-6 -7.005735f-5 1.4939279f-6 -0.00013089488 -8.1744096f-5 0.00015944988 -6.63168f-5 -0.00013518886 -4.57937f-5 -3.079429f-5 4.7998354f-5 9.9412064f-5 7.693266f-5 8.3618295f-5 8.730748f-5 -5.3218657f-5 -9.584254f-5 9.128776f-5 0.00018133802 0.00012256826 4.178292f-5 -4.8912745f-5 0.0001228515 0.000114630784 2.9938721f-5 -3.039159f-5 9.4483315f-5 -7.922542f-5 7.438026f-5; -9.595736f-5 -0.00022084406 7.247254f-5 4.1949843f-5 -0.00015315785 -0.00022651452 -0.00016521312 0.00020725328 2.0725292f-5 -8.9046924f-5 -0.0002162433 4.9733404f-5 -0.00038460369 0.00018115916 -2.0690153f-5 3.4407976f-6 5.29355f-5 0.000112251895 0.00017864064 -1.0289447f-5 0.00013865551 0.00012592251 0.00018702744 6.237925f-5 0.00018104273 -5.384121f-5 7.997439f-5 3.9490005f-5 8.901868f-5 -6.039903f-5 -7.54671f-5 -0.00018633714; -0.00019335863 0.00012619802 -2.464479f-5 3.35384f-5 -8.6703905f-5 -0.00013406319 -2.9536448f-5 -9.3444005f-6 0.00012185372 -2.714295f-5 -6.825135f-5 -0.00011535594 -0.00011941291 0.00013883019 -3.8839855f-5 2.1380689f-5 0.00014749738 9.543163f-5 0.00016885837 0.00015049281 0.00015135514 -5.2961346f-5 -4.1714058f-5 7.11919f-5 0.000107373744 9.199155f-5 3.7997048f-5 0.00010370876 -9.4872084f-5 5.0598974f-5 -7.593949f-5 0.00010512711; -7.399502f-5 
-0.00026247068 -5.601294f-5 -0.00014133962 -1.8535897f-5 4.2690044f-5 6.628586f-6 -0.00021238708 1.9753295f-5 0.00023486992 -4.2082065f-5 4.0196996f-5 -1.2395394f-6 -5.844502f-5 -9.2070775f-5 -0.00016446819 -3.455364f-5 0.00013924968 -5.7955003f-5 0.00013939988 -5.9499536f-5 -4.9564813f-5 -4.5096203f-6 2.5578707f-5 8.6171494f-5 3.453342f-5 -9.461335f-6 9.866274f-5 4.8781287f-5 -0.00013221746 6.11103f-5 9.748156f-5; -7.965214f-5 -1.5861397f-5 7.1270947f-6 7.197809f-5 0.00015599423 5.2392756f-5 3.6805133f-5 -6.34207f-5 -0.00012741618 8.5245665f-5 -2.2188713f-5 0.00014004845 -5.739394f-5 -6.5082036f-6 -0.0001495476 -0.000100811536 -9.5142525f-5 -2.1177228f-5 -5.2826068f-5 -4.7992773f-5 4.562157f-5 -1.4327982f-5 0.00015636723 1.617366f-5 -5.061588f-5 3.0520263f-5 7.0884955f-5 -0.00018241096 -2.8315999f-5 7.900629f-5 6.278875f-6 1.7987868f-5; -5.8718604f-5 0.00011179331 8.096355f-5 7.006408f-5 -1.3242688f-5 9.204574f-5 -5.2579257f-5 -5.083016f-5 -0.00017958722 3.8176437f-5 7.583376f-5 0.00019385175 0.00018152164 -1.9536365f-5 0.00015488907 -1.566161f-5 1.966738f-5 -2.3911323f-5 9.6550626f-5 5.5117787f-5 -0.00012023813 6.324298f-5 -0.0001485289 1.4616806f-5 5.769116f-6 5.1852225f-5 3.2225613f-5 -0.00012643255 -4.771313f-5 0.00012620464 6.671385f-5 9.244772f-5; -0.0002268267 5.0094095f-5 -3.674683f-5 0.00010526814 -0.00010284339 1.3732185f-5 -5.1300034f-5 9.453549f-5 -0.00015227614 -6.754226f-5 2.3986326f-5 3.2110472f-6 -9.232776f-5 -3.6716305f-5 0.00011593329 0.00015761089 -4.4642038f-5 -0.00011288082 9.283776f-5 -4.4072833f-5 1.3899358f-5 -6.0605504f-5 -4.2519667f-5 -0.00014545384 0.0002939735 0.00013688378 -0.000112266796 8.177227f-6 9.3419236f-5 -7.695847f-5 -1.8017323f-5 6.889727f-5; 9.812861f-5 -5.716916f-6 -3.4769605f-6 1.1166299f-5 1.5697246f-5 -7.946251f-5 5.556992f-5 4.1072548f-5 0.00019452357 3.3677137f-5 -6.5746965f-5 8.729643f-5 -3.587417f-5 -2.5361885f-5 4.8945392f-5 -4.2377596f-5 -0.00014625877 -9.549663f-5 -0.0002083573 -0.0001003506 -4.1377094f-5 
-4.860768f-6 -5.779258f-5 -6.852436f-5 -2.9022045f-5 -5.770054f-5 5.569337f-6 -3.9808165f-5 -7.4234784f-5 0.00013304933 0.00013247601 -7.553607f-5; 0.00014356471 7.176553f-5 -0.00013460491 8.3037135f-5 -1.6215815f-6 -0.00017612234 3.0936488f-5 5.5220284f-5 -2.0805304f-5 -0.00020436272 -9.7442404f-5 -0.00013150772 0.00014083559 -0.00010797869 2.4472842f-5 8.445365f-5 -5.648099f-5 -0.0001082766 1.134091f-5 5.076903f-5 -0.00011323451 0.00018238046 -7.248719f-5 3.402577f-5 1.8596684f-5 -3.673104f-5 -1.04992105f-5 -6.649003f-5 -7.754755f-6 -0.00012555346 -6.903754f-5 -7.658849f-5], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.00022097296 -6.538917f-5 -0.000107125226 1.23211785f-5 -4.5818706f-5 -8.189573f-5 2.9421119f-5 -3.349178f-5 -8.818589f-5 1.3807584f-5 -2.4260306f-5 2.0350057f-5 2.7347907f-5 9.951691f-5 -7.862674f-5 -7.678361f-5 3.5268917f-5 1.0153902f-5 -0.0002696228 -4.278026f-6 6.868719f-5 4.1912986f-5 -0.00027113847 8.738122f-5 -2.2390101f-7 -2.6578378f-5 8.5153304f-5 0.0001626191 -4.1411735f-5 2.2202757f-5 -7.890124f-5 3.1969663f-5; 0.0001813832 0.00012095564 0.00015143862 -0.000111511275 1.0450687f-5 -0.00018076738 6.064111f-5 8.947477f-5 1.789411f-5 -0.00018079931 6.1545384f-6 -9.2438684f-5 -0.00010161966 -9.8071454f-5 6.6900386f-5 6.0208222f-5 1.1785471f-6 -2.6961962f-7 -8.976395f-5 -0.00010468097 -3.2005228f-5 4.6320212f-5 8.2225066f-5 -1.1371372f-5 4.7112506f-5 6.686672f-5 1.0326576f-6 -6.0779756f-5 -1.4772366f-5 4.0669847f-6 6.176199f-5 6.097681f-5], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64.

julia
# Convert the Float32 parameters produced by `Lux.setup` into a flat Float64
# ComponentArray, which is the vector the optimiser updates.
const params = ComponentArray{Float64}(ps)

# Bundle the network with its state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(static(:direct_call), Base.Fix1{typeof(broadcast), typeof(cos)}(broadcast, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, with each rate corrected by the neural network, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural orbit ODE, for state `u = [χ, ϕ]` and network
parameters `nn_params`.

Both rates share the Newtonian-form prefactor `(1 + e cos χ)² / (M p^(3/2))`, with
`(p, M, e)` taken from the global `ode_model_params`. Each rate is multiplied by
`1 + NN(χ)`, where NN is the network's corresponding output. Returns `[χ̇, ϕ̇]`.
`t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # `nn_model` is a StatefulLuxLayer that carries its own state `st` (an empty
    # NamedTuple here), so only the trainable parameters are passed in.
    correction = 1 .+ nn_model([χ], nn_params)

    prefactor = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [prefactor * correction[1], prefactor * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same fixed-step
# RK4 settings as the true model so the two waveforms are sampled identically.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only h₊, matching the reference `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained-network waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # One legend entry per series (line + markers), placed at the left-bottom corner.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural ODE `prob_nn` with parameters `θ` and turn the trajectory into the
h₊ waveform. Returns `(mse, predicted_waveform)`, where `mse` is the mean squared
error against the reference `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted = first(compute_waveform(dt_data, Array(sol), mass_ratio,
        ode_model_params))
    mse = mseloss(waveform, predicted)
    return mse, predicted
end
loss (generic function with 1 method)

Warm up the loss function (the first call triggers compilation):

julia
# The first call compiles the whole solve → waveform → MSE pipeline, so later calls
# (e.g. during training) do not pay that cost.
loss(params)
(0.000705466337911033, [-0.024258282170308704, -0.023473790466993164, -0.02268929876367787, -0.021363773435413326, -0.01946950173531184, -0.016966257616034323, -0.013801228835217923, -0.009906056827445074, -0.005199105918382931, 0.000414883660979762, 0.007035587369444872, 0.014746146209176835, 0.023565548897291677, 0.03332022520701497, 0.04336570832189008, 0.051928523926119095, 0.05471411347170794, 0.042636760791463196, 0.002151174800758711, -0.06612193403205877, -0.1103183852203226, -0.07645680857702872, -0.0069640754381139, 0.03876781637044641, 0.05426717990685296, 0.052951916521593825, 0.04484208818636801, 0.034820076241670876, 0.02492905919279689, 0.015930044680329103, 0.008038455678188408, 0.001253702812508349, -0.004504512507366092, -0.009336868148421053, -0.013341238876316544, -0.016601501473648375, -0.019188979566434065, -0.02115939697820815, -0.022555054337893177, -0.023405723565389176, -0.023729625033884764, -0.02353378175295329, -0.022815291802708568, -0.021559010822560006, -0.019739650361066452, -0.0173194487256972, -0.014247408189442192, -0.010458731114598222, -0.005874036159146701, -0.00040088062189276206, 0.0060587971061093265, 0.013592851411675187, 0.022230774821839622, 0.031840291441796516, 0.04187479747573307, 0.05079923564539124, 0.05484327051148609, 0.045825418834590965, 0.010468645675400125, -0.0552949760799388, -0.10820381329856793, -0.08596086199944115, -0.016808673985208498, 0.034162337657379124, 0.05345574513891443, 0.05384652844007196, 0.046295614238700156, 0.036336528363185155, 0.026321368807144994, 0.017144483669315466, 0.009067911202180428, 0.0021175813089516562, -0.0037913285870779284, -0.008750004281024947, -0.012867667494325571, -0.01622355790666165, -0.018898740514512593, -0.020945139910267713, -0.022412753594055708, -0.023329465487663632, -0.02371804509142951, -0.02358641068180372, -0.022932518320144857, -0.021745056145326066, -0.02000047957373044, -0.017660086250218043, -0.0146811739254148, -0.010994404826548, 
-0.006531094639290075, -0.0011933347559518483, 0.005107575657182254, 0.012469238093534424, 0.020924852330519925, 0.030382997941902613, 0.04037610547860026, 0.049583543032706026, 0.054699404306907645, 0.048395364403451775, 0.017947159976960843, -0.04428874267273546, -0.10406130435445339, -0.09431067573989338, -0.027253669625480034, 0.028765849653448257, 0.052226557200426224, 0.054586965450013895, 0.04771564297603509, 0.0378656832060003, 0.0277441607712672, 0.018386553062087745, 0.010129105127297446, 0.0030020300816063434, -0.0030536660252870372, -0.008149026004179993, -0.012376580791414286, -0.01583588486140641, -0.018595491007914265, -0.020723204420038197, -0.022260598019691238, -0.02324593619189107, -0.023698522894226557, -0.023630647059911452, -0.023042608386025492, -0.021922143379568836, -0.020250691781519906, -0.017990310157875198, -0.015100839132036126, -0.011515530162886266, -0.007168878263735761, -0.0019648505695605157, 0.004182635258115187, 0.011373334922981128, 0.019649717432883497, 0.028948018847038126, 0.038877591042281306, 0.04829856578859222, 0.05431879850596131, 0.050411443285134104, 0.02457286659521101, -0.03337161672454403, -0.09807648064058341, -0.1011855189825078, -0.038134424780102975, 0.022534338633186313, 0.05052604329488581, 0.05514262669785801, 0.04909059232051789, 0.0394059025798922, 0.02919014762538431, 0.019664645410787007, 0.011215294393564095, 0.003915482740984676, -0.0023007999997934373, -0.007524440625658445, -0.011873596876914376, -0.01543292715166004, -0.018284988141785842, -0.020488363815292843, -0.022101246738172757, -0.023153749779952788, -0.023671431599285975, -0.023667450328471538, -0.023143855450316417, -0.022091463490047546, -0.02048915207327733, -0.01831176437929662, -0.015506809397040102, -0.0120206400026877, -0.0077896788382846305, -0.0027166170678130223, 0.0032836264859726033, 0.010306823061528154, 0.018402309849933025, 0.027538044758360112, 0.03738395829231114, 0.046958417664770026, 0.05373973529943774, 
0.051928560319408255, 0.030370086196155794, -0.022792086459521653, -0.09048417270995467, -0.10631514785474266, -0.04923942483569254, 0.015448271301020805, 0.04828706956898254, 0.05548368069589194, 0.050406382055708604, 0.040948304585649783, 0.030668782015354376, 0.020968772669521532, 0.012332463310636489, 0.004850909844148307, -0.001517768813220168, -0.006887119359954116, -0.011354157807790763, -0.015019191378289148, -0.017957801338103187, -0.020247368512526093, -0.021932451596321997, -0.02305374576688176, -0.023635985011765843, -0.02369658362553404, -0.023237185781319284, -0.022250289297309322, -0.020722732388349686, -0.018618845200889512, -0.015901227724818702, -0.012512103384303714, -0.008392126237045999, -0.0034466810538021984, 0.002407824381559784, 0.009266296737845656, 0.01718602367513923, 0.026154207919956513, 0.03589873806694502, 0.04557688926919312, 0.05299020306428377, 0.05300610179033838, 0.03536733826806634, -0.01272691509582713, -0.0815895497228537, -0.10949148182088825, -0.060296955919775824, 0.007492337478210942, 0.04545543063525977, 0.055569224329182175, 0.0516469621098202, 0.04249181264283251, 0.032168438463032335, 0.022306291428819007, 0.01347938327811533, 0.005815190087477647, -0.000716668876094642, -0.006228714612447393, -0.010818607444890528, -0.014590485835654449, -0.01762231129826768, -0.019994720086968152, -0.02175472258482869, -0.022945292663231028, -0.0235931056172061, -0.02371775823336329, -0.023322547522936485, -0.022401918624235642, -0.020944385584641123, -0.018916719299015554, -0.016283401199545477, -0.012989344511896685, -0.008977791316341296, -0.004966238120785782])

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Optimisation callback. Appends the current loss `l` to the global `losses`, prints a
progress line, and returns `false` so the optimiser keeps going. `θ` and
`pred_waveform` are accepted but not used.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf "Training %10s Iteration: %5d %10s Loss: %.10f\n" "" iteration "" l
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This appears to work well because the Newtonian model already provides a very good initial guess.

julia
# Gradients of the loss are obtained with Zygote. The second argument of the objective
# (the Optimization.jl hyper-parameter slot) is unused: `loss` is called with the trainable
# parameters alone.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search. Starting close to a good
# fit (the Newtonian initial guess), large first steps are undesirable.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback=callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-8.867039286997169e-5; 7.416553853545338e-5; -8.352342410944402e-5; 1.3414470231509767e-5; -0.00018793318304233253; -5.385300028137863e-5; 0.00017667203792370856; -3.992570054833777e-5; -3.1561274226987734e-5; 0.00011852990428451449; 8.573634113417938e-5; 8.974296360975131e-5; 4.237295070197433e-5; 5.088926263852045e-5; 6.988627865212038e-5; -0.00011852392344735563; -5.109255198476603e-6; 6.82497484376654e-5; -8.96060882951133e-5; -0.00011109934712294489; 4.097667260793969e-5; -8.304290531668812e-5; -8.961093953985255e-6; -2.357383527851198e-5; -0.00013556018529925495; 6.406412285286933e-5; -0.0001637011591810733; 9.301529644289985e-6; -0.0001256190298590809; -7.946415280457586e-5; -8.323340443894267e-5; 0.00015809274918865412;;], bias = [0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = [-2.466322621330619e-5 5.4696698498446494e-5 0.00014637160347774625 -3.448088318691589e-5 -4.037038888782263e-5 5.405099000199698e-5 7.297680713236332e-5 -0.00011844118125736713 0.00010473408474354073 -0.00011177561827935278 0.00020245990890543908 0.00020026895799674094 8.797423652140424e-5 3.5891618608729914e-5 6.781177944503725e-5 8.262758638011292e-5 0.0001261131401406601 1.4971125892770942e-5 -0.00022475456353276968 -0.00018247519619762897 4.237917892169207e-5 -0.00010268067853758112 0.00013310954091139138 6.364161890815012e-6 -6.0619287978624925e-5 2.3454644178855233e-5 -0.00011768659169320017 -5.419136505224742e-5 0.00015187183453235775 -0.0001490161521360278 -3.2116509828483686e-5 -0.00018643119256012142; 8.400931255891919e-5 8.190447988454252e-5 -7.579534576507285e-5 -0.0001449052506359294 -8.026669092942029e-5 0.00013214832870289683 6.33655654382892e-5 7.207032467704266e-5 1.4029108569957316e-5 2.152230626961682e-5 -7.369554077740759e-5 -0.00014524483412969857 
-3.151979035465047e-5 -8.233093831222504e-5 -6.0942082200199366e-5 2.3419233912136406e-5 -3.539545286912471e-5 6.8046856540604495e-6 -5.9554971812758595e-5 4.3536409066291526e-5 -2.7638150640996173e-5 0.00014080459368415177 -3.6216013540979475e-5 -4.52483982371632e-5 -3.9989266952034086e-5 4.909185372525826e-5 9.576976481184829e-6 0.0001519581419415772 5.9358942962717265e-5 -9.657484770286828e-5 -3.135981387458742e-5 1.1793401426984929e-5; -5.253960262052715e-5 0.00011845507106045261 -0.0002430690365144983 9.69722168520093e-6 8.829695434542373e-5 0.00010465295781614259 2.872728873626329e-5 -5.513152791536413e-5 0.00012723248801194131 -0.00010711240611271933 -7.708885095780715e-5 5.56174672965426e-5 -1.9370278096175753e-5 -2.416076858935412e-5 0.00011528500181157142 -3.136871600872837e-5 0.0001466119138058275 -8.949744369601831e-5 0.00012131466792197898 -0.00013116146146785468 -7.79738838900812e-5 0.00017541114357300103 1.787727524060756e-5 -1.7823138477979228e-5 0.00017857323109637946 -2.911903357016854e-5 0.00017016495985444635 5.59481923119165e-5 0.00012793125642929226 -5.944788790657185e-5 0.00014498931705020368 9.690766091807745e-6; 2.9551620173151605e-5 0.00016055408923421055 -0.00014412394375540316 -3.269355875090696e-5 -6.627649418078363e-5 -3.4136766771553084e-5 2.472135929565411e-5 -4.641793202608824e-5 -0.00020297938317526132 -3.506650682538748e-5 1.0180168828810565e-5 -5.0856331654358655e-5 0.00015109706146176904 8.587551565142348e-5 0.00021773236221633852 -3.49271394952666e-5 -0.00013812026008963585 8.358050399692729e-5 9.752847108757123e-5 -2.225999651273014e-6 -8.101472485577688e-5 -8.812099986243993e-5 -5.8578618336468935e-5 -2.1236470274743624e-5 -3.2052554161055014e-5 -0.0001120085726142861 -0.00015449007332790643 4.754836845677346e-5 -8.56635015225038e-5 -3.266894054831937e-5 0.00012996479927096516 0.00015126800280995667; 1.216500186274061e-5 -1.990616738112294e-6 -4.478767095861258e-6 1.6926487660384737e-5 0.00016864626377355307 
-0.00012150254042353481 1.8440421627019532e-5 -6.134338036645204e-5 0.00012013406376354396 -3.83633523597382e-5 3.089751407969743e-5 3.503678090055473e-5 6.470618245657533e-5 -2.8322860089247115e-5 -3.4378252166789025e-5 -1.901807809190359e-5 -6.267512071644887e-5 0.0001050620194291696 -4.059283310198225e-5 -9.230559226125479e-5 0.00017425604164600372 -0.00023758337192703038 7.999390618351754e-6 -2.167421371268574e-5 9.792498894967139e-5 -3.700004890561104e-5 -8.174924732884392e-5 -6.200884672580287e-5 0.000213896157220006 -5.589401553152129e-5 -0.00011051850015064701 -8.046602306421846e-5; -6.791505438741297e-5 -0.0001482118823332712 8.422407699981704e-5 -5.898953531868756e-5 6.697294884361327e-5 -4.299899228499271e-5 0.00017475972708780318 -0.00036095522227697074 -4.727284613181837e-5 6.996582669671625e-5 -8.023564441828057e-5 -0.00011565366730792448 -4.468608312890865e-5 6.091454270062968e-5 -1.9955672541982494e-5 6.283881521085277e-5 -0.00010737865522969514 -3.11860057990998e-5 -0.00018124704365618527 2.7625203074421734e-5 -0.00011038894444936886 -8.93720643944107e-5 4.411057670949958e-5 9.235368634108454e-5 7.072232983773574e-5 4.207093297736719e-5 6.544264033436775e-5 -5.849988156114705e-5 -0.00015023669402580708 5.6155611673602834e-5 -3.0056055038585328e-5 -7.770206138957292e-6; 1.9214807252865285e-5 -3.9501315768575296e-5 1.8342041130381403e-6 -0.00014592485968023539 0.00010553781612543389 -4.432902278495021e-5 5.0162281695520505e-5 1.3590230992122088e-5 8.224198245443404e-5 0.00013017692253924906 0.0001412126439390704 -0.0001568508450873196 5.003892147215083e-5 -0.00010419412865303457 -0.00013356495765037835 -8.686779437994119e-6 0.00013745338947046548 -1.8051629012916237e-5 -0.0001520366786280647 -0.0001936340268002823 3.4004395274678245e-5 2.3496219000662677e-5 0.00012532653636299074 0.00012756188516505063 -4.627462476491928e-5 7.178277883213013e-5 6.20294304098934e-5 9.534680316392041e-7 4.0353694203076884e-5 5.9009264077758417e-5 4.8922174755716696e-5 
4.1693168896017596e-5; 6.092023249948397e-5 0.00010437741730129346 0.00015709483704995364 1.135179445554968e-5 8.456237992504612e-5 -0.0001171014373539947 8.596784755354747e-5 3.535802989063086e-6 -8.1651407526806e-5 -3.942572948290035e-5 -7.320776057895273e-5 -2.447974111419171e-6 1.5222952242766041e-5 3.9299935451708734e-5 -1.7043714251485653e-5 6.513080006698146e-5 -0.00016239701653830707 7.056157483020797e-5 -3.7803783925483e-5 -0.0001811208057915792 -0.00015309145965147763 -0.00022680692200083286 8.786711987340823e-5 5.837282515130937e-5 2.8744681912939996e-5 2.928445837824256e-6 0.00014017695502843708 -0.00011983780132140964 4.354081102064811e-5 -0.00010086141264764592 0.00012593631981872022 0.00011232370889047161; -4.659167825593613e-5 0.0002056625671684742 0.0001405050716130063 -1.8498887584428303e-5 -9.801278065424412e-5 -0.00019698055984918028 0.00013821669563185424 1.1619829820119776e-5 -4.1388113459106535e-5 -1.9489507394609973e-5 -0.00010723295417847112 4.699141936725937e-5 -0.00012979494931641966 4.284330861992203e-5 -0.0001146839204011485 7.25511199561879e-5 -4.5639113523066044e-5 0.00012155308650108054 7.581497629871592e-5 3.552370981196873e-5 0.00011717350571416318 9.649079584050924e-5 0.00017460749950259924 4.112009992240928e-5 0.00021681857469957322 -6.2102240008243825e-6 3.366549572092481e-5 2.197298817918636e-5 3.065612690988928e-5 -8.914562204154208e-5 3.421819201321341e-5 5.926754965912551e-6; 0.00023464004334528 2.9943514164187945e-5 2.427689469186589e-5 -0.00010158290388062596 8.745663217268884e-5 0.0001173319251392968 5.693575440091081e-5 -0.00010349413787480444 0.00014455433120019734 7.147360884118825e-5 -8.741042984183878e-5 6.663093518000096e-5 2.154074354621116e-5 -6.26840119366534e-5 -6.249638681765646e-5 -8.668393093103077e-6 5.3953099268255755e-5 -1.3646824299939908e-5 0.00014352182915899903 -0.00020002148812636733 -3.2004867534851655e-5 0.00012372915807645768 0.00015520468878094107 4.4085118133807555e-5 -0.00016380136366933584 
-0.00012338437954895198 -4.082620216649957e-5 -0.00019018408784177154 -1.0502018994884565e-5 -8.477702795062214e-5 0.00015526694187428802 -4.462281140149571e-5; -9.947593207471073e-5 1.5063004639159772e-6 0.00014460056263487786 -0.00012083986803190783 0.0001495787873864174 0.00013345338811632246 1.4485000974673312e-5 0.00014783572987653315 4.1128540033241734e-5 -0.00010654258949216455 0.00011303878272883594 -0.00012587319361045957 -7.460110646206886e-5 -9.359641990158707e-5 -6.426757317967713e-5 -0.00012133081327192485 8.922033885028213e-5 5.868209700565785e-5 -2.7119656806462444e-5 -2.8317333999439143e-5 -5.493486696650507e-6 -2.4675700842635706e-5 4.936378900310956e-5 -2.523279545130208e-5 0.00020656095875892788 -6.964171916479245e-5 8.010763849597424e-5 -4.463967707124539e-5 -8.32837467896752e-5 -5.248762317933142e-5 0.00017180632858071476 -1.635514126974158e-5; 0.00013875799777451903 0.00019382889149710536 -5.863604019396007e-5 3.330571416881867e-5 9.172356658382341e-5 7.284862658707425e-5 0.00014184352767188102 -0.00016869959654286504 -0.00012629995762836188 2.8608666980289854e-5 -4.4426426029531285e-5 -0.0001228889886988327 -7.57505331421271e-5 -2.202706855314318e-5 4.4801632611779496e-5 -0.00012027543561998755 -3.284939521108754e-5 8.926542795961723e-5 -8.84871042217128e-5 -2.7350390155334026e-5 4.641874329536222e-5 -0.00011460029054433107 -9.235031757270917e-5 8.678620361024514e-6 -1.845011684054043e-5 -0.00012753588089253753 -7.8639954153914e-5 -5.047198646934703e-5 0.00013541847874876112 -0.00015862811414990574 -0.00015877957048360258 -2.419402517261915e-5; -4.866076778853312e-5 -0.0001339115115115419 -7.349682709900662e-5 8.028180309338495e-5 -0.00015986323705874383 2.1024288798798807e-5 4.921440267935395e-5 9.39538367674686e-5 -3.904749519278994e-6 0.00010680931154638529 4.695955431088805e-5 -6.885977199999616e-5 1.3254656550998334e-5 3.577755342121236e-5 1.938163404702209e-5 -2.6401435206935275e-6 8.386740228161216e-5 -0.00014774437295272946 
-0.0001610299659660086 -1.2484294529713225e-5 3.933436164516024e-5 5.531563510885462e-5 -1.7530168406665325e-5 3.695443592732772e-5 0.00024348217993974686 -6.913200923008844e-5 2.04280368052423e-5 4.034016455989331e-5 -2.4997636501211673e-5 1.478077956562629e-5 -0.00020621238218154758 -0.000127016261103563; -9.660256182542071e-5 -8.569873898522928e-5 0.00015073585382197052 -5.6945067626656964e-5 -7.208649913081899e-5 -7.954289321787655e-5 -5.493525168276392e-5 -0.00016404643247369677 -9.478762513026595e-5 4.456848546396941e-5 0.00011648602230707183 -0.00013451778795570135 -0.00011554227239685133 -0.00010831838881131262 0.00013346571358852088 3.677752465591766e-5 -6.782666605431587e-5 -9.95873415376991e-5 -3.5779028166871285e-6 -9.93983030639356e-6 0.0001622530398890376 -0.00013576418859884143 -2.8019037927151658e-5 1.2431867617124226e-5 -6.63523496768903e-6 -6.57971904729493e-5 -0.00010876025771722198 -2.3931681425892748e-5 0.00014934310456737876 8.960934678725607e-7 -6.459863652708009e-5 2.2051315681892447e-5; -6.142743222881109e-5 -0.00021140834724064916 7.427176751662046e-5 2.2229660316952504e-5 -4.1423365473747253e-5 -3.8283953472273424e-5 1.3606229913420975e-5 -0.00012412127398420125 3.284031117800623e-5 -0.00013675590162165463 5.596190385404043e-5 -0.0001610544277355075 3.849248605547473e-5 6.379136175382882e-5 2.441039441691828e-6 -2.1775204004370607e-6 -7.829234527889639e-5 -0.00012222219083923846 -1.477337718824856e-5 0.00022841518511995673 0.00010553008178249002 -4.532672755885869e-5 -6.274809857131913e-5 0.00017479938105680048 -0.00010043077054433525 3.934117921744473e-5 -5.449872332974337e-5 9.453291568206623e-6 -3.946384458686225e-5 5.495117511600256e-5 0.0001410942495567724 -5.434024387795944e-6; 8.825866188999498e-7 9.379605216963682e-6 -0.00010700232814997435 8.258739399025217e-5 -4.410700421431102e-5 -4.4118252844782546e-5 5.3092491725692526e-5 -6.087179281166755e-5 -3.1992996810004115e-5 2.9288639780133963e-5 -0.00011599106073845178 
4.650874325307086e-5 -0.00010917619511019439 -1.5897889170446433e-5 -2.7586016585701145e-5 0.00010901787754846737 9.675400360720232e-5 -0.00020747800590470433 -2.3323606001213193e-5 -7.844530773581937e-5 -0.000224196701310575 8.522390999132767e-6 9.25050480873324e-5 6.040493099135347e-5 0.0001174964418169111 0.00011075076326960698 0.00012912294187117368 0.0001731474621919915 -1.5519646694883704e-5 -4.615071156877093e-5 -0.00012067084026057273 2.5620542146498337e-5; -2.7755304472520947e-5 -6.999635661486536e-5 1.887780126708094e-5 -1.4189198736858089e-5 8.918493404053152e-5 3.5951048630522564e-5 -0.00015220949717331678 0.00014610186917707324 -6.887940980959684e-5 -4.9323778512189165e-5 -7.621989061590284e-5 0.0001717632549116388 -0.00010984086839016527 0.00013584186672233045 9.413452062290162e-5 -5.527384564629756e-5 2.918661448347848e-5 -0.00011949046893278137 -6.354327342705801e-5 1.2808068277081475e-5 -6.099098391132429e-5 -5.3717303671874106e-5 5.346431498765014e-5 6.907138595124707e-5 -6.637186015723273e-5 -6.647229383816011e-6 -3.066840872634202e-5 3.610361454775557e-5 -0.00014115293743088841 4.491740037337877e-5 -5.794070602860302e-5 -9.970000246539712e-5; -4.692430593422614e-5 -3.5758010199060664e-5 -0.00012556184083223343 -3.3914719097083434e-5 8.185405022231862e-5 -5.194079494685866e-5 1.2575187611219008e-5 -4.8329438868677244e-5 -4.3117572204209864e-5 -9.074991248780861e-5 -3.9778780774213374e-5 -8.309073018608615e-5 -1.0714321433624718e-5 0.00012701227387879044 -0.0001291517255594954 -6.163953366922215e-5 6.186515383888036e-5 -1.3553347343986388e-5 -3.0433731808443554e-5 -0.00013782102905679494 0.00016800987941678613 -7.071869913488626e-5 2.949287924991495e-8 -0.00012035652616759762 6.552430568262935e-5 4.2722222133306786e-5 -3.335501969559118e-5 6.839123670943081e-5 0.00020893024338874966 1.959665860340465e-5 -7.570181332994252e-5 -6.798499816795811e-5; 9.848736954154447e-5 -2.3925760615384206e-5 -0.0001503437670180574 -3.9717356230539735e-6 
-0.00012080145825166255 2.907411180785857e-5 -0.00010654374636942521 -0.0001797384029487148 0.0001591045001987368 5.055359360994771e-5 -0.00011038080265279859 -1.8568274754215963e-5 -3.059833397855982e-5 -7.147366704884917e-5 5.2602626965381205e-5 7.808861846569926e-5 4.1807728848652914e-5 -1.1131241080875043e-5 7.596139766974375e-5 8.845467527862638e-5 7.263701991178095e-5 0.00011922620615223423 5.03162773384247e-5 -4.7048390115378425e-5 -0.00012459057325031608 0.00014102245040703565 -6.895789556438103e-5 -0.00011361391079844907 -9.058658906724304e-5 -0.00019780147704295814 -5.696034713764675e-5 -6.647657573921606e-5; 5.777213300461881e-5 -9.078851871890947e-5 3.079162343055941e-5 -5.967927791061811e-5 -6.611829303437844e-5 -0.00010625726281432435 -9.632312139729038e-5 -4.9096641305368394e-5 -4.2643667256925255e-5 -5.143137968843803e-5 4.140479722991586e-5 -5.213568510953337e-5 -0.0001362433104077354 2.0565632894431474e-6 -1.9270846678409725e-5 0.00019906203669961542 -0.00019169400911778212 -9.514681732980534e-5 -0.00011255626304773614 7.439997716573998e-5 0.00020648055942729115 6.095308708609082e-5 3.103248673141934e-5 -0.0001315944828093052 8.621011511422694e-5 -3.237229975638911e-5 -1.0417005796625745e-5 0.00012581064947880805 -0.00018636899767443538 -8.165399776771665e-6 0.0001578217779751867 1.876530768640805e-5; -0.00017145938181784004 0.00014257352449931204 -6.639429921051487e-5 -4.550748781184666e-5 0.0001540961384307593 2.7156967917107977e-5 -0.00016239316028077155 0.00012629949196707457 -0.0001342421892331913 3.215532342437655e-5 1.0596836546028499e-5 8.80699953995645e-5 0.00015023263404145837 0.00019157789938617498 -3.5178254620404914e-5 -0.00010572186147328466 -2.6710937163443305e-5 0.000196970664546825 0.00021229452977422625 -0.00011167061165906489 -4.8576766857877374e-5 6.14617601968348e-5 2.4059572751866654e-5 6.99715965311043e-5 -0.00023017489002086222 5.6888362450990826e-5 0.00023506117577198893 -6.008891432429664e-5 -2.6454246835783124e-5 
-2.7534913897397928e-5 -7.879479380790144e-5 6.972333358135074e-5; 0.00013471163401845843 -2.2137553969514556e-5 -6.0058497183490545e-5 4.466842074180022e-5 5.680064714397304e-5 -1.3217894775152672e-5 -0.00014348045806400478 3.95893475797493e-5 -0.00022157048806548119 2.5118199573626043e-6 -0.0001102731330320239 2.52940044447314e-5 7.523897511418909e-5 -5.6310873333131894e-5 7.683866715524346e-5 7.431382255163044e-5 -8.521806012140587e-5 -8.740853809285909e-5 -1.6345180483767763e-5 6.383443542290479e-5 -4.009584881714545e-5 9.31393587961793e-5 7.184759306255728e-5 -7.749127689749002e-5 0.00010707052570069209 -6.652814772678539e-5 0.00010703050065785646 8.189026993932202e-5 9.693171159597114e-5 -0.0001016912137856707 0.00019646667351480573 -5.796562732029997e-7; 0.00021729007130488753 5.086722376290709e-5 -9.619254706194624e-5 -9.873550652628182e-7 6.446504266932607e-5 1.3033980394538958e-5 1.1135826753161382e-5 0.00019197813526261598 8.931649790611118e-5 -8.208585495594889e-5 -0.00010658076644176617 -6.536584260175005e-5 -4.0113369323080406e-5 -0.00011690591054502875 -5.981415142741753e-6 -0.00010851218394236639 -4.8830261221155524e-5 9.854731615632772e-6 -0.00014579991693608463 0.00024444592418149114 -8.616298146080226e-5 2.1912457668804564e-5 1.6573707398492843e-5 5.938200047239661e-5 0.00012043838069075719 -5.006534047424793e-5 0.00024555710842832923 2.713893809414003e-5 -8.4439716374618e-6 -9.340864926343784e-5 0.00011356261529726908 3.2079169614007697e-5; 9.301771933678538e-5 0.00013235677033662796 -0.00015491890371777117 -1.4700351584906457e-6 -7.005735096754506e-5 1.4939279253667337e-6 -0.00013089488493278623 -8.174409595085308e-5 0.00015944987535476685 -6.63168029859662e-5 -0.00013518886407837272 -4.5793698518536985e-5 -3.079429006902501e-5 4.799835369340144e-5 9.941206371877342e-5 7.693265797570348e-5 8.361829532077536e-5 8.73074823175557e-5 -5.3218656830722466e-5 -9.584253712091595e-5 9.128775855060667e-5 0.00018133802223019302 0.00012256826448719949 
4.1782921471167356e-5 -4.8912745114648715e-5 0.00012285150296520442 0.00011463078408269212 2.9938721127109602e-5 -3.0391589461942203e-5 9.448331547901034e-5 -7.922542135929689e-5 7.438025932060555e-5; -9.595735900802538e-5 -0.0002208440564572811 7.247253961395472e-5 4.1949842852773145e-5 -0.00015315784548874944 -0.00022651451581623405 -0.00016521311772521585 0.00020725328067783266 2.0725292415590957e-5 -8.904692367650568e-5 -0.0002162432938348502 4.973340401193127e-5 -0.000384603685233742 0.00018115916464012116 -2.0690153178293258e-5 3.440797627263237e-6 5.293549838825129e-5 0.00011225189518881962 0.00017864063556771725 -1.0289447345712688e-5 0.00013865550863556564 0.0001259225100511685 0.00018702744273468852 6.237925117602572e-5 0.000181042734766379 -5.384120959206484e-5 7.997438660822809e-5 3.949000529246405e-5 8.901867840904742e-5 -6.0399030189728364e-5 -7.546709821326658e-5 -0.00018633714353200048; -0.00019335863180458546 0.00012619802146218717 -2.464479075570125e-5 3.3538399293320253e-5 -8.670390525367111e-5 -0.00013406318612396717 -2.953644798253663e-5 -9.344400496047456e-6 0.00012185372179374099 -2.7142950784764253e-5 -6.825134914834052e-5 -0.00011535594239830971 -0.00011941290722461417 0.00013883018982596695 -3.8839854823891073e-5 2.1380688849603757e-5 0.0001474973832955584 9.543162741465494e-5 0.00016885837248992175 0.00015049280773382634 0.00015135513967834413 -5.2961346227675676e-5 -4.171405817032792e-5 7.11918983142823e-5 0.00010737374395830557 9.19915473787114e-5 3.7997047911630943e-5 0.00010370875679655 -9.487208444625139e-5 5.059897375758737e-5 -7.593948976136744e-5 0.00010512711014598608; -7.399501919280738e-5 -0.0002624706830829382 -5.601294105872512e-5 -0.00014133962395135313 -1.8535896742832847e-5 4.269004421075806e-5 6.628586106671719e-6 -0.00021238707995507866 1.975329541892279e-5 0.0002348699199501425 -4.208206519251689e-5 4.019699554191902e-5 -1.2395394151099026e-6 -5.8445020840736106e-5 -9.207077528117225e-5 -0.00016446819063276052 
-3.4553639125078917e-5 0.00013924967788625509 -5.795500328531489e-5 0.0001393998827552423 -5.949953629169613e-5 -4.9564812798053026e-5 -4.509620339376852e-6 2.5578707209206186e-5 8.617149433121085e-5 3.45334192388691e-5 -9.461335139349103e-6 9.866274194791913e-5 4.8781286750454456e-5 -0.0001322174648521468 6.111030234023929e-5 9.748156298883259e-5; -7.965214172145352e-5 -1.5861396605032496e-5 7.127094704628689e-6 7.197808736236766e-5 0.00015599423204548657 5.239275560597889e-5 3.680513327708468e-5 -6.342070264508948e-5 -0.000127416176837869 8.524566510459408e-5 -2.2188713046489283e-5 0.00014004844706505537 -5.7393939641769975e-5 -6.5082035689556506e-6 -0.000149547602632083 -0.00010081153595820069 -9.514252451481298e-5 -2.117722760885954e-5 -5.282606798573397e-5 -4.799277303391136e-5 4.562157118925825e-5 -1.4327982171380427e-5 0.00015636722673662007 1.617366069694981e-5 -5.061587944510393e-5 3.0520262953359634e-5 7.088495476637036e-5 -0.00018241096404381096 -2.831599886121694e-5 7.900629134383053e-5 6.278874934650958e-6 1.7987867977353744e-5; -5.871860412298702e-5 0.00011179331340827048 8.096355304587632e-5 7.006408122833818e-5 -1.3242687600723002e-5 9.204573871102184e-5 -5.257925658952445e-5 -5.083016003482044e-5 -0.0001795872231014073 3.8176436646608636e-5 7.583376282127574e-5 0.00019385175255592912 0.00018152163829654455 -1.9536364561645314e-5 0.00015488907229155302 -1.566160972288344e-5 1.9667379092425108e-5 -2.3911323296488263e-5 9.655062603997067e-5 5.5117787269409746e-5 -0.00012023813178529963 6.324298010440543e-5 -0.00014852889580652118 1.4616805856348947e-5 5.769115887233056e-6 5.1852224714821205e-5 3.2225612812908366e-5 -0.00012643255467992276 -4.7713128878967836e-5 0.00012620464258361608 6.67138519929722e-5 9.244772081729025e-5; -0.0002268266980536282 5.0094095058739185e-5 -3.674683102872223e-5 0.0001052681400324218 -0.0001028433907777071 1.3732184925174806e-5 -5.130003410158679e-5 9.453549137106165e-5 -0.0001522761449450627 -6.754226342309266e-5 
2.398632568656467e-5 3.2110472147905966e-6 -9.23277621041052e-5 -3.671630474855192e-5 0.00011593328963499516 0.0001576108916196972 -4.464203811949119e-5 -0.00011288082168903202 9.283776307711378e-5 -4.4072832679376006e-5 1.3899358236812986e-5 -6.0605503676924855e-5 -4.2519666749285534e-5 -0.0001454538432881236 0.00029397351318039 0.00013688378385268152 -0.00011226679635001346 8.17722684587352e-6 9.341923578176647e-5 -7.69584730733186e-5 -1.8017322872765362e-5 6.889727228553966e-5; 9.812860662350431e-5 -5.71691589357215e-6 -3.476960500847781e-6 1.1166299373144284e-5 1.5697245544288307e-5 -7.946250843815506e-5 5.556991891353391e-5 4.1072547901421785e-5 0.00019452357082627714 3.36771372531075e-5 -6.574696453753859e-5 8.729643013793975e-5 -3.58741708623711e-5 -2.5361885491292924e-5 4.8945392336463556e-5 -4.237759640091099e-5 -0.0001462587679270655 -9.549663081998006e-5 -0.0002083573053823784 -0.00010035059676738456 -4.1377094021299854e-5 -4.8607680582790636e-6 -5.7792578445514664e-5 -6.852435762993991e-5 -2.9022045055171475e-5 -5.770054121967405e-5 5.5693371905363165e-6 -3.9808164729038253e-5 -7.423478382406756e-5 0.0001330493250861764 0.0001324760087300092 -7.553606701549143e-5; 0.00014356471365317702 7.17655275366269e-5 -0.00013460491027217358 8.303713548230007e-5 -1.6215815321629634e-6 -0.000176122339325957 3.0936487746657804e-5 5.522028368432075e-5 -2.080530430248473e-5 -0.00020436271734070033 -9.744240378495306e-5 -0.00013150772429071367 0.0001408355892635882 -0.00010797868890222162 2.447284168738406e-5 8.445364801445976e-5 -5.648098886013031e-5 -0.00010827660298673436 1.1340909622958861e-5 5.0769031076924875e-5 -0.00011323451326461509 0.0001823804632294923 -7.248719339258969e-5 3.402576840016991e-5 1.8596683730720542e-5 -3.673103856272064e-5 -1.049921047524549e-5 -6.649002898484468e-5 -7.754754733468872e-6 -0.0001255534589290619 -6.903753819642588e-5 -7.658849062863737e-5], bias = [0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 
0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = [0.00022097295732237399 -6.538916932186112e-5 -0.00010712522635003552 1.2321178473939653e-5 -4.581870598485693e-5 -8.189572690753266e-5 2.942111859738361e-5 -3.349177859490737e-5 -8.818588685244322e-5 1.3807583854941186e-5 -2.4260305508505553e-5 2.035005672951229e-5 2.7347907234798186e-5 9.951691026799381e-5 -7.862674101488665e-5 -7.678360998397693e-5 3.5268916690256447e-5 1.0153901712328661e-5 -0.0002696228038985282 -4.278026153770043e-6 6.868719356134534e-5 4.19129864894785e-5 -0.000271138473181054 8.738121687201783e-5 -2.239010115090423e-7 -2.65783783106599e-5 8.515330409863964e-5 0.00016261909331660718 -4.141173485550098e-5 2.220275746367406e-5 -7.890124106779695e-5 3.196966281393543e-5; 0.0001813832059269771 0.00012095564306946471 0.00015143862401600927 -0.00011151127546327189 1.0450687113916501e-5 -0.00018076738342642784 6.064110857550986e-5 8.947477181209251e-5 1.7894109987537377e-5 -0.00018079931032843888 6.1545383687189315e-6 -9.243868407793343e-5 -0.00010161966201849282 -9.80714539764449e-5 6.690038571832702e-5 6.020822183927521e-5 1.178547108793282e-6 -2.696196190754563e-7 -8.976394747151062e-5 -0.0001046809702529572 -3.200522769475356e-5 4.632021227735095e-5 8.222506585298106e-5 -1.1371372238500044e-5 4.711250585387461e-5 6.686671986244619e-5 1.0326575647923164e-6 -6.077975558582693e-5 -1.4772365830140188e-5 4.066984729433898e-6 6.176198803586885e-5 6.097681034589186e-5], bias = [0.0; 0.0;;]))

Visualizing the Results

Let us now plot the loss against the iteration number.

julia
# Loss history: one point per optimizer iteration, as recorded in `losses`.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # Draw a line and markers on top of it; both use the 1-based iteration index as x.
    lines!(ax, eachindex(losses), losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses;
        strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us compare the waveform predicted by the trained neural ODE with the data and with the untrained model.

julia
# Re-solve the model with the optimized parameters `res.u`. Fixed-step RK4 with step `dt`
# (adaptive stepping disabled) saves the solution at `tsteps`. The solution is then turned
# into a waveform, keeping only the first value returned by `compute_waveform`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, dt, saveat=tsteps, adaptive=false))
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first

# Overlay the data, the untrained neural-ODE waveform, and the trained one.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Every series: a line plus semi-transparent circular markers.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    # Group each line with its markers so every legend entry shows both glyphs.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
# Report Julia/system information. Then, for each GPU backend whose package is loaded and
# which LuxDeviceUtils reports as functional, print that backend's toolkit details. The
# `&&` chains short-circuit, so device names are only touched when their package exists.
InteractiveUtils.versioninfo()

if @isdefined(LuxDeviceUtils) && @isdefined(CUDA) &&
   LuxDeviceUtils.functional(LuxCUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxDeviceUtils) && @isdefined(AMDGPU) &&
   LuxDeviceUtils.functional(LuxAMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 4 default, 0 interactive, 2 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 4
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.735 GiB / 4.750 GiB available)

This page was generated using Literate.jl.