Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

# Make scalar (element-by-element) indexing of CUDA arrays an error rather than a silent,
# slow fallback, so accidental scalar GPU access is caught early.
CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; from a scientific machine learning perspective, however, these details are not particularly relevant.

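We need a very crude 2-body path. Assume the 1-body motion is the Newtonian 2-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to recover $r_1$ and $r_2$: with total mass $M = m_1 + m_2$, $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$ (see e.g. *Theoretical Mechanics of Particles and Continua*, §4.3).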

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) path `path = r₁ - r₂` into the individual positions
`(r₁, r₂)` of two bodies with masses `m₁` and `m₂`, measured from their common centre of
mass (so that `m₁ r₁ + m₂ r₂ = 0`).
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    return (m₂ / total_mass) .* path, (-m₁ / total_mass) .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = \frac{p}{1 + e\cos\chi}$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution expressed in the orbital variables `(χ, ϕ)` into Cartesian
orbit coordinates.

`soln` has one column per time sample and either

  - 2 rows `(χ, ϕ)`: the orbital constants are then taken from `model_params`, which
    must be a length-3 collection `(p, M, e)` (`M` is not used here), or
  - 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary along the solution.

Returns a `2 × size(soln, 2)` matrix whose rows are `x` and `y`, computed from
`r = p / (1 + e cos χ)`, `x = r cos ϕ`, `y = r sin ϕ`.

Throws an `AssertionError` if `soln` has neither 2 nor 4 rows, or if it has 2 rows and
`model_params` is missing or does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first, so an omitted `model_params` reports this message
        # instead of a `MethodError` from `length(nothing)`.
        @assert(!isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function computes the first time derivative using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` with sample
spacing `dt`. Interior samples use the second-order central difference
`(v[i+1] - v[i-1]) / 2dt`; the first and last samples use second-order one-sided
stencils. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    # Forward one-sided stencil (-3/2, 2, -1/2) at the left edge.
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Central differences for every interior sample.
    interior = (v[3:end] .- v[1:(end - 2)]) ./ 2
    # Backward one-sided stencil (3/2, -2, 1/2) at the right edge.
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, interior, right_edge) ./ dt
end
d_dt (generic function with 1 method)

This function computes the second time derivative using the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` with sample
spacing `dt`. Interior samples use the three-point stencil `(1, -2, 1) / dt²`; the first
and last samples use second-order one-sided stencils `(2, -5, 4, -1) / dt²`. Requires
`length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    # One-sided stencil reaching forward from the first sample.
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # Three-point central stencil for every interior sample.
    body = v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    # Mirror-image stencil reaching backward from the last sample.
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, body, tail) ./ dt^2
end
d2_dt2 (generic function with 1 method)

Now we define functions to compute the trace-free quadrupole moment tensor from the orbit and, from it, the (2,2)-mode gravitational-wave strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free quadrupole moment `mass * (xᵢxⱼ - δᵢⱼ (x² + y²)/3)`
of a planar orbit, evaluated at every time sample.

`orbit` is a matrix whose first two rows are `x` and `y`. `component` is an index pair:
`(1, 1)` selects the `xx` component, `(2, 2)` the `yy` component, and any other pair
yields the off-diagonal `xy` component.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    trace = x .^ 2 .+ y .^ 2
    i, j = component

    moment = if i == 1 && j == 1
        x .^ 2 .- trace ./ 3
    elseif i == 2 && j == 2
        y .^ 2 .- trace ./ 3
    else
        x .* y
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the selected trace-free quadrupole component
(see `orbit2tensor`) of `orbit`, using finite differences with sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole strain components for a single body of
mass `mass` moving on `orbit`. Each component comes from `h_22_quadrupole_components`
with sample spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_of(index_pair) = h_22_quadrupole_components(dt, orbit, index_pair, mass)
    h11 = component_of((1, 1))
    h22 = component_of((2, 2))
    h12 = component_of((1, 2))
    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain polarizations of a single (test-particle) body on `orbit`,
using finite differences with sample spacing `dt`.

Returns `(√(π/5) * (h11 - h22), -√(π/5) * 2h12)`, where `h11`, `h12` and `h22` come from
`h_22_quadrupole`. `T`, the type of `dt`, sets the precision of the numeric constants.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2)-mode normalization is √(π/5). The square root was lost in the original
    # text, which left the constant as π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return `(h11, h12, h22)` for a two-body system. Each entry is the sum of the
corresponding single-body components from `h_22_quadrupole` for body 1
(`orbit1`, `mass1`) and body 2 (`orbit2`, `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # Element-wise tuple sum: (h11₁ + h11₂, h12₁ + h12₂, h22₁ + h22₂).
    return map(+, body1, body2)
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain polarizations of two bodies: body 1 of mass `mass1` on
`orbit1` and body 2 of mass `mass2` on `orbit2`. Finite differences use sample spacing
`dt`.

Returns `(√(π/5) * (h11 - h22), -√(π/5) * 2h12)` using the summed components from
`h_22_quadrupole_two_body`. Throws an `AssertionError` unless
`mass1 + mass2 == 1` (to within 1e-12).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2)-mode normalization is √(π/5). The square root was lost in the original
    # text, which left the constant as π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain polarizations `(h₊, hₓ)` from an ODE solution.

  - `dt`: spacing between the time samples (columns) of `soln`, used by the
    finite-difference derivatives. Its type `T` sets the precision of the constants.
  - `soln`: solution matrix accepted by `soln2orbit` (2 or 4 rows).
  - `mass_ratio`: `m₁ / m₂`, which must lie in `[0, 1]`. `0` selects the one-body
    (test-particle) waveform.
  - `model_params`: `(p, M, e)`, required when `soln` has 2 rows.

Throws an `AssertionError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # The comparison operators were lost in the original text (`@assert mass_ratio1`),
    # which made these checks meaningless. Restore `≤ 1` and `≥ 0`.
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Normalize the masses so that m₁ + m₂ = 1 and m₁ / m₂ = mass_ratio.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

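`RelativisticOrbitModel` defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background, with state $u[1] = \chi$, $u[2] = \phi$:

$$\dot\chi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p-2)^2 - 4e^2}}, \qquad \dot\phi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p-2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.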

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = [χ, ϕ]`, with constant
parameters `(p, M, e)`. The rates depend only on `χ`; `t` is unused. Returns the
two-element vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    e_cos_χ = e * cos(χ)

    # Factor shared by both rates.
    numer = (p - 2 - 2 * e_cos_χ) * (1 + e_cos_χ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ_rate = numer * sqrt(p - 6 - 2 * e_cos_χ) / (M * (p^2) * denom)
    ϕ_rate = numer / (M * (p^(3 / 2)) * denom)
    return [χ_rate, ϕ_rate]
end

# Configuration of the "true" (relativistic) simulation that generates the training data.
mass_ratio = 0.0         # test particle: `compute_waveform` takes the one-body branch
u0 = Float64[π, 0.0]     # initial conditions (χ, ϕ) = (π, 0)
datasize = 250           # number of saved time samples
# NOTE(review): `tspan` is Float32 while `u0` is Float64, so `tsteps` and `dt_data` are
# Float32 as well. Confirm this mixed precision is intended.
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing of saved samples, used for finite differences
dt = 100.0               # fixed RK4 step size (the solves use `adaptive=false`)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns (h₊, hₓ); only the plus polarization is kept as the target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform. The line and scatter are grouped into a single legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the orbital angle χ, i.e. the first component of the ODE state) and has two outputs. We'll make a function `ODE_model` that takes the current ODE state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and instead use the `truncated_normal` initializer from WeightInitializers.jl with a very small standard deviation (`1e-4`).

julia
# Network with 1 input and 2 outputs:
#   layer 1 applies `cos` element-wise to the input,
#   layers 2-3 are width-32 `Dense` layers with `cos` activations,
#   layer 4 is a linear `Dense` layer with 2 outputs.
# Every weight matrix is drawn from a truncated normal with std 1e-4, and the biases
# start at zero, so the untrained outputs are small.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# NOTE(review): `Xoshiro()` is unseeded, so the initial parameters differ between runs.
# Pass a seed (e.g. `Xoshiro(0)`) for reproducible results.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.0001253726; 0.00012816812; 9.224394f-5; 3.0486752f-5; -2.2661845f-5; 3.115078f-5; 3.5190256f-5; 4.2731914f-5; -2.2489096f-6; -4.7470952f-5; -9.278142f-5; 1.5630816f-5; -3.5392943f-5; 8.884549f-6; -1.15579805f-5; 7.044403f-5; 4.5001932f-5; 6.961939f-5; -0.00012714528; -9.0974274f-5; -5.063351f-5; -5.2669282f-5; 8.388948f-5; 0.0002985441; 0.00033257937; -1.7594622f-5; 0.00011912561; 9.642335f-5; 8.0435675f-5; -4.2967335f-5; -0.00020805448; -2.7180793f-5;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[-6.008906f-5 4.542796f-5 -2.2868518f-5 7.046881f-5 9.843544f-5 -9.165628f-5 5.5358596f-5 -3.279566f-5 -0.00012260441 -0.00012328335 0.00013304957 -6.420584f-5 0.00020756888 8.387777f-5 8.929312f-5 -0.00014829448 6.813939f-5 -0.000122051446 3.660188f-5 8.523472f-5 -9.224378f-5 5.96306f-5 0.0002010029 3.6622234f-6 -0.0002254078 -5.3598276f-5 3.8156595f-5 0.00012327447 0.00010974734 0.00014456858 7.448392f-5 2.615327f-5; 9.9618504f-5 -0.00016609205 5.2010022f-5 -0.00010961037 0.00011962865 -4.580462f-5 4.1270185f-5 -0.000105193336 -0.000120230674 -0.00015014368 2.3368286f-6 -2.342843f-5 0.00026303608 -4.342555f-6 0.00015543099 -1.791166f-5 1.7722734f-5 1.2043515f-5 -0.00013169622 -3.69776f-5 0.0002769712 -3.8565595f-5 0.00011689869 -0.00010036015 -0.000106303916 -1.7187049f-5 0.00016364001 -0.00012242825 0.00021932689 -0.00015334293 5.1045172f-5 7.3968324f-5; -0.00015246811 7.9321144f-5 -7.899363f-5 1.1222935f-5 -3.2371983f-5 1.195506f-5 0.000105951105 7.203449f-5 1.0281939f-5 0.00013498557 -1.6547292f-5 -7.106412f-5 -4.703192f-5 3.8583345f-5 3.382777f-5 -1.2923209f-5 0.00027907823 0.00013168868 7.1761515f-6 -3.3376393f-5 -9.309273f-5 4.5351182f-5 -5.414618f-6 -9.302012f-5 4.6208956f-5 -0.00011257594 0.00010082199 2.5308134f-5 0.0001805303 -3.4213062f-5 
2.799851f-5 9.491753f-5; -1.1064164f-5 0.00016366516 -0.00010079857 8.956571f-5 -1.2111361f-5 6.63934f-5 -5.4554945f-5 -8.497785f-5 -0.00012545125 7.844014f-5 -0.00010753696 -0.00011047121 3.747697f-5 -0.00012538407 -3.2302465f-5 -4.713638f-5 -0.00015517141 -3.9433562f-5 0.0001663785 8.6438675f-5 1.2926645f-5 -2.9991194f-5 3.0193216f-5 -2.9400073f-5 8.4165935f-5 -6.472843f-5 -0.0001716227 1.985605f-5 7.547925f-6 -2.5223928f-5 0.00010060235 6.107933f-7; 0.00022673054 0.00013736103 -5.163585f-5 -0.00010295694 -7.6411656f-5 0.00016375634 5.765735f-5 -1.3810919f-5 0.00010411896 7.236314f-5 0.00016682842 0.00011338982 0.00012786838 7.375667f-5 1.1673065f-5 9.900223f-6 -8.3475476f-5 0.00011639862 7.748621f-5 -7.974854f-6 8.9547495f-5 -0.00018833633 -1.053091f-5 0.0001612253 -2.0482166f-5 6.482124f-5 -1.1799855f-5 4.2259195f-5 4.4753782f-5 -0.00013092076 3.8651517f-5 0.00018375363; 3.3519653f-5 -0.00011991507 9.456128f-6 -0.00019213167 -7.054388f-5 -6.093023f-5 -5.966628f-6 0.00026953855 6.223933f-5 -0.00021491769 -7.108303f-5 -4.0439063f-5 4.023458f-5 -6.449317f-5 6.2640254f-5 -5.0696854f-5 6.5258873f-6 -2.3640003f-5 -3.87994f-5 8.313636f-6 -0.00026224714 -2.8569698f-6 -0.00018214982 5.654069f-5 -0.00012389262 3.898388f-5 -0.00016264776 2.3244838f-5 -1.8775283f-5 -5.5495526f-5 0.00011597751 0.00012948683; 5.4806937f-5 -0.00020170622 -2.1998536f-5 -5.541724f-5 -8.079806f-5 5.2412448f-5 7.994055f-6 -0.00020841783 0.00015121426 -2.3936639f-6 0.00013245014 -6.8931484f-5 -0.00012224365 -0.00016351925 -3.688163f-5 4.1378884f-5 -1.525489f-5 -8.350892f-6 2.4456967f-5 0.00019978803 4.0826842f-5 4.0829f-5 -1.8928435f-5 -0.00016185036 -8.299632f-5 0.00011545076 -7.734352f-5 0.00021520001 1.0582127f-5 -0.0001397139 -1.2805794f-5 -2.3401082f-5; 1.4455803f-5 0.00012871885 -0.00012271664 3.728695f-6 -0.0002257178 3.832365f-5 -8.613317f-5 -2.8309476f-6 5.2713585f-5 -1.6331772f-6 5.456693f-6 -3.5744277f-5 6.183518f-6 -0.00022681536 -1.2221644f-5 -0.00020841193 -0.00011963833 
-4.4052653f-5 -9.770669f-5 4.586166f-6 1.2798217f-5 -4.8955317f-6 -0.00017193689 -4.179448f-5 0.0001330998 0.00016837414 7.677278f-5 0.00012481489 -4.447257f-5 -2.5691541f-5 7.4335905f-5 -0.00013993717; -2.1552582f-5 5.149736f-5 2.5171614f-5 -2.6010131f-5 -7.987964f-7 0.000103567756 -0.0001332783 5.568864f-7 9.045554f-5 4.7005604f-5 1.9392308f-5 8.470403f-5 -1.3164753f-6 2.8361086f-5 0.00011057365 7.4914184f-5 0.00010263546 -0.00014887896 6.0948525f-5 3.203961f-5 -3.2321765f-5 0.00010823574 2.915375f-5 3.664396f-5 0.00017048609 -9.930068f-5 0.00013144541 0.000101691774 0.00021992212 7.9332196f-5 -1.4978954f-5 0.00013611102; 8.0591075f-5 0.0001228908 0.00010454233 -6.501862f-5 0.00014180416 0.00014183909 -0.00029739612 0.00013653954 -0.00011930413 -0.00012400991 -7.527758f-6 4.0497827f-5 -0.00012410582 4.2417723f-5 1.2337588f-5 0.00016942849 -5.84643f-6 6.2785607f-6 -5.1296458f-5 6.1170413f-6 4.1718286f-5 -9.4514806f-5 0.00013690014 1.0089253f-5 -7.328188f-5 -5.3973396f-5 -0.00012933837 8.831326f-5 0.00016792379 0.000102444385 5.9017937f-5 2.8701319f-5; 5.782917f-5 -2.6016807f-5 -7.6065495f-5 1.213212f-6 -0.00016574403 9.797648f-5 4.96952f-5 -3.220772f-5 0.000106005704 -0.00014204881 5.426169f-6 -0.000116579395 1.8192462f-6 -6.3326304f-5 -7.842911f-5 5.8523416f-5 0.00011972597 -0.00013036211 -7.439924f-5 -4.256977f-5 6.138781f-5 -0.0001052631 -6.1268006f-6 -8.1579055f-5 5.0316958f-5 8.685121f-5 0.00011510628 -0.00018137516 9.651078f-5 5.4611475f-5 -4.6939272f-7 0.00012406068; 1.4062471f-5 8.991071f-5 -9.3064715f-5 -7.3755225f-5 5.038496f-6 -9.159249f-5 -2.1776f-5 -7.9484f-5 -6.280654f-5 -0.00012931575 7.249739f-5 0.00013251665 -5.7224137f-5 6.1359795f-5 0.0003050777 -0.00017147271 6.92256f-5 -3.7431182f-5 9.503915f-5 0.00015108872 -3.0084302f-5 -1.7660208f-5 -6.7344986f-6 2.0656818f-5 8.4723855f-5 -6.1043844f-5 5.235908f-5 6.826914f-5 8.238448f-5 -8.113788f-6 0.00012745566 1.7185746f-5; -0.00014151195 1.6905267f-5 3.917412f-5 -9.5064715f-5 8.558023f-5 -8.687761f-5 
-0.00013078317 0.00016486124 0.00026561768 -0.000108975255 0.00011941096 -9.406498f-5 0.00026794625 0.00010374207 1.1254705f-5 1.9628125f-5 -3.473407f-5 -0.00010944869 -6.558816f-5 3.6184112f-5 -7.864561f-5 -0.00011309643 -2.3387109f-5 -9.458307f-5 2.1303347f-5 -0.00014693098 5.66448f-5 -0.0001564281 -6.49064f-5 -1.4508405f-5 -2.6732983f-5 0.00011253191; 6.750816f-5 0.0002885414 -4.15977f-5 -0.000102519654 5.4724693f-5 5.38286f-5 2.5124442f-5 0.00011710336 4.4870852f-5 -7.2133753f-6 -7.587472f-5 -8.055374f-6 -0.0001730003 -5.1844027f-6 -3.630573f-6 -4.9484406f-6 0.00021220745 -6.7657656f-6 5.1075982f-5 -0.00018997864 -0.00020105689 8.991581f-5 2.9718847f-5 6.1081796f-6 -3.4397493f-5 -2.9210907f-5 0.00010406153 -0.00013104727 -3.802992f-5 -7.0752285f-5 -4.0166662f-5 5.180139f-5; 0.00013664953 0.000104940045 -3.4072204f-5 6.2963685f-5 -0.00012016572 -6.2303676f-5 -1.4324798f-5 5.4774628f-5 -5.0690924f-5 -4.771149f-5 -0.00012610939 3.548493f-5 -0.000102360966 6.210718f-6 -6.809151f-5 -5.082593f-5 6.0168542f-5 -2.6699045f-5 -9.523829f-5 -8.3329305f-6 -6.898066f-5 -9.9613644f-5 4.437343f-5 -4.3095064f-5 -1.2355805f-5 0.00011344149 -0.00011146349 -0.00016323027 0.00015002134 -2.598923f-6 7.2110524f-5 7.293142f-5; 5.417247f-5 0.00011744604 -3.479022f-5 -5.6088425f-5 -6.0949966f-5 2.2351092f-5 3.571905f-5 0.00011195463 3.4165256f-5 -5.7408703f-5 1.4962062f-5 0.0001443794 8.8650675f-5 8.22244f-6 -0.00011324365 -1.1001302f-5 0.00023044301 0.00012356139 -6.548126f-5 0.00015889565 7.803961f-5 -0.00013597954 -3.1522755f-5 -7.4552117f-6 -2.623296f-5 9.0982954f-5 5.640537f-6 -6.557017f-5 0.00020550904 -8.1545135f-5 -0.00010877734 4.3319087f-5; 5.6748213f-5 0.00010620347 -6.820981f-5 0.00012445347 3.6930427f-7 -4.9318587f-5 -7.550813f-5 -8.4988555f-5 0.00026575883 -4.1344105f-5 -1.6208542f-5 -1.7878077f-5 2.041742f-5 0.00039854468 -1.7590086f-5 8.950063f-5 0.00010741805 -0.00011408617 -0.00017899726 -3.0934203f-5 -2.5899213f-5 -3.0687166f-5 -8.2491824f-5 -1.7970586f-5 -7.701214f-6 
-0.00011960249 -3.7614715f-5 1.4305311f-5 -2.0053515f-6 5.10296f-5 -4.4048273f-5 5.1965537f-5; 1.28438005f-5 0.00017369552 4.468856f-5 -5.3192947f-5 4.8592177f-5 -3.121244f-5 4.7394158f-5 0.00022708147 -3.0317384f-5 0.00014156358 -3.1710104f-5 0.00010998828 -0.00011377353 8.176469f-5 4.759487f-5 9.927664f-5 -5.390298f-5 1.71931f-5 -8.499132f-5 6.1199453f-6 8.029985f-5 4.323649f-5 -6.382598f-5 -2.44911f-6 -0.00019673927 5.229466f-5 0.00019560267 0.00022473201 -2.5662881f-5 -0.0001554867 3.9395516f-5 -7.327031f-5; -3.633786f-5 2.1308398f-7 -0.0001614739 -0.00025207706 -7.4886906f-5 6.558857f-5 -4.9722486f-5 8.472175f-6 0.00010889393 -9.666002f-5 -0.00012636161 3.0172434f-5 -4.2931388f-5 5.179524f-5 0.0002517364 3.6612822f-5 -2.1386375f-5 -1.9245246f-5 1.5963562f-5 -0.0001100648 -1.682372f-5 -8.286563f-5 -0.00011720453 6.633552f-5 -0.00012875858 0.00021093183 -0.00021814842 6.273354f-5 -8.957385f-5 1.7042383f-5 -0.000115992676 2.897984f-6; -3.4016586f-5 -2.59013f-5 2.6169904f-5 4.758647f-5 -0.00010826929 0.000101587095 8.1904705f-5 -0.00011583823 -1.701631f-5 3.3881566f-5 6.1679953f-6 -5.227397f-5 0.00015272366 8.2071f-5 -7.883658f-5 9.8618526f-5 -5.0695307f-5 -9.610622f-5 0.000216721 0.00014913545 -1.6586917f-5 4.2735846f-6 8.367928f-5 -0.0001520336 -0.0002938041 -5.8389596f-5 -6.7307315f-5 -4.8881833f-5 -2.5224144f-5 -4.7806272f-5 0.000220458 7.925562f-5; -0.00016665948 -0.00020686787 -0.00023723063 0.0001588202 -2.9479348f-5 -0.00015747423 0.00014253949 -7.0361486f-5 1.7354276f-5 3.9647057f-5 -1.2852146f-5 8.755829f-5 -5.382071f-5 -7.147556f-5 -4.1502582f-5 -6.5295746f-5 -6.342053f-5 1.4325055f-5 1.5416601f-5 -1.5754227f-5 1.1163248f-5 -4.3991928f-5 4.6620742f-5 -0.00012975807 -0.00012868924 -7.5687945f-5 6.9953516f-5 -0.00011088165 -0.00013837084 0.0001526886 -0.000112755384 -9.746957f-5; -8.170751f-6 -9.629996f-5 -2.7500819f-5 7.579724f-5 4.2426145f-6 0.00020490936 5.4837616f-5 3.256761f-5 -5.2053678f-5 0.00013074794 -3.7855443f-5 -8.709057f-5 2.2095759f-5 
-6.01192f-5 2.1646543f-5 0.000100279045 0.0001500651 0.00010989723 1.4232606f-5 6.499191f-9 -0.00010460386 -1.3345401f-6 7.089724f-6 -2.4665384f-5 -7.4520816f-5 -0.00018406239 4.8593403f-5 -6.732744f-5 -0.00012263248 3.8064616f-5 0.00028555363 -8.676727f-5; -3.0978637f-5 -0.00012353972 -4.0634757f-5 -1.6098573f-5 -0.00015040195 -1.9052553f-5 7.36669f-5 -9.545437f-5 9.394282f-5 -3.6537374f-5 6.42978f-5 4.0423158f-5 -0.00015851014 -0.00022510502 -6.387985f-5 7.674027f-5 0.00012353002 0.0001181441 0.00015544529 1.7926084f-5 -0.00010915 4.665437f-5 -9.0042064f-5 0.00017885752 0.00014476129 2.5468406f-5 9.8898134f-5 -7.679519f-5 7.9269645f-5 3.9533143f-6 -8.980645f-5 -0.00012060348; 8.654732f-6 2.8645976f-5 0.00018625277 -6.890072f-5 4.1658637f-5 -7.914293f-5 6.318184f-5 0.00012346014 -5.0885184f-5 -4.4475637f-5 7.084644f-5 -0.0001427258 -3.7978916f-5 -0.00018405296 -1.5561484f-6 0.00019219244 0.00014829986 -1.2194113f-5 5.626547f-5 0.0001881784 1.5527889f-5 0.00018812461 -0.00014725806 1.0738726f-5 -2.3847637f-5 -2.6792464f-5 -2.3869788f-5 -8.281088f-5 -0.00017648992 2.8756469f-5 -0.00018526317 5.360901f-5; -0.0001946946 8.7902255f-5 0.00026421924 -2.514182f-5 0.00022039372 -2.7098367f-6 0.00014758136 7.9404585f-5 7.65535f-5 -0.00012748245 -7.369717f-5 5.9902144f-5 -8.967087f-5 -0.00013016663 3.2495223f-5 8.002463f-5 8.9877794f-5 2.0900677f-6 1.942559f-5 3.702201f-5 -7.957776f-5 1.9427935f-5 -8.4512585f-6 -7.5404187f-6 -3.7303238f-5 -1.619316f-5 -0.00019354945 -5.381439f-5 0.00012591967 -0.00014563135 -1.2258373f-5 -5.1737723f-5; 0.00018326213 0.00018729946 4.873298f-5 0.00012693008 6.72333f-5 -7.82017f-6 -6.152589f-5 1.0626375f-5 0.00012542088 -2.3230346f-6 -1.9315934f-5 -0.00021650067 -0.0001683108 0.00010278978 9.218735f-5 -1.600312f-5 -3.4672008f-5 1.1356758f-5 -7.759856f-5 -6.454152f-5 -3.3659046f-5 0.00012797663 -1.9253954f-5 0.00012409604 0.0001430867 5.004762f-6 -7.778902f-5 2.7476508f-6 9.838457f-6 0.00011546546 7.700415f-5 -2.684548f-6; -5.5837f-5 
1.8642246f-5 0.00016949518 -7.916467f-5 -3.724935f-5 3.67531f-5 -8.4292944f-5 0.00011122766 0.00026308506 -5.3231037f-5 6.375765f-5 0.000108870474 3.6365458f-5 -8.483376f-6 -4.805406f-5 -5.7667374f-5 7.361467f-5 0.00018777508 8.2519124f-5 -0.00016758971 -3.843943f-5 -0.00010196889 9.225055f-5 2.7459828f-5 -5.499765f-5 8.0889484f-5 6.931203f-5 -6.0589977f-5 -0.00021018714 1.6718564f-5 -5.0319315f-5 4.7834797f-5; -0.0001209295 0.00020836045 0.000115584466 -4.84983f-5 -3.337979f-5 1.787035f-5 1.9116502f-5 3.3849938f-5 -0.00012697584 -9.3498165f-5 0.00011695487 5.2109783f-5 -0.00018208846 -3.2333475f-5 -3.8429906f-5 -0.00015594842 -0.00011758128 -0.00012012199 3.7609938f-5 -0.00021085436 -6.417347f-5 -3.4417168f-5 6.367152f-5 6.1598774f-5 -8.084452f-5 -8.927126f-5 1.1185033f-5 -5.669734f-5 4.991897f-5 7.198059f-5 -8.825196f-6 0.00025767004; -5.9606795f-5 -2.9043911f-5 0.00014432435 3.0763954f-6 9.51451f-5 -7.1388085f-6 -1.3357036f-5 0.0001523067 -0.00019427466 -6.923415f-6 -0.00014516209 -1.1386863f-5 0.00010630589 -8.322884f-5 -7.283596f-5 2.6924417f-5 8.9736866f-5 -0.00019368176 0.00013873399 -2.8289158f-5 -0.00013414279 -4.2938253f-5 -4.5269597f-5 -4.307836f-5 6.4523025f-5 9.99884f-6 0.00017860034 -3.4798f-5 9.300363f-5 8.425746f-5 8.637702f-5 9.4499956f-5; -4.7352754f-5 8.1819635f-5 3.0226603f-5 -0.00017700979 -3.38456f-5 2.0846132f-6 -0.00016612203 0.00014089802 0.00010119764 -6.112183f-5 -5.549169f-5 -5.5340937f-5 2.9084307f-5 -7.4798f-5 -0.000117221396 -7.120909f-5 -0.00018337826 -0.00020002383 -8.744016f-5 0.000112411464 0.0001632129 -4.2701042f-5 0.00019768042 2.9053448f-5 0.0001107459 -9.001838f-5 2.0468191f-5 3.815621f-5 -8.937415f-5 -4.4978857f-5 -4.223956f-5 -5.63189f-5; -8.498253f-5 0.00013502229 -6.780409f-5 5.6289748f-5 -2.3252805f-5 -0.00010345472 -0.00012106566 3.340658f-5 -0.000119074895 0.00010782463 0.00024830512 -1.7187884f-5 6.622863f-5 1.6379143f-5 -0.00014495595 3.9969905f-6 4.7439047f-5 0.00014280966 -1.6384669f-5 -4.8975853f-5 4.6943754f-7 
-0.00013702351 0.00019797032 9.642844f-5 -0.00016748835 0.00011347465 -5.9059712f-5 -8.7085944f-5 -3.054113f-5 -2.633095f-5 8.187274f-5 0.0001193916; 8.503744f-5 4.7714973f-5 -0.00017844726 0.00016355743 0.00011235684 1.752668f-5 -0.00020055025 5.805073f-5 -0.00010292699 -5.146567f-6 -6.27873f-5 4.2212887f-5 -8.898079f-6 0.00015697314 -5.115096f-5 7.65575f-5 4.6795267f-5 2.0881082f-5 -9.916777f-5 -4.6552082f-5 0.00014374575 0.00017625398 -7.4865275f-5 -0.00010843208 5.226777f-5 -0.00018888562 -6.002323f-5 -0.00011650161 -5.7355577f-5 0.00015991942 -2.0403645f-6 0.00020783405], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.00025235894 3.3736436f-5 2.0235639f-5 4.218651f-5 -0.0001447382 0.00010682502 1.9351692f-5 -6.5573804f-5 0.00017227058 0.00030142968 -7.9128215f-5 4.5542904f-5 -0.00010868347 -0.00013034919 -7.243883f-5 4.9205388f-5 2.2522678f-5 0.00025145255 0.00013512254 0.00010957698 0.00020686888 -7.5906806f-5 0.00024555618 -0.00010865807 0.00011146277 -0.00013444619 -9.2567876f-5 0.00013364285 0.00013104294 -2.1013217f-5 0.00014538053 7.19304f-5; -1.5717394f-5 -7.936949f-5 0.000102745194 3.3879332f-5 1.4059f-5 0.00014087299 -3.5583928f-6 1.2046585f-5 -3.2315893f-5 -8.7277855f-5 1.1423255f-5 -9.718666f-5 3.3375131f-6 9.53629f-5 5.9623537f-5 0.00010772769 -2.1164316f-5 -5.130476f-5 -2.2857552f-5 8.1975915f-5 4.5331f-5 3.9157876f-5 9.322063f-5 3.2646563f-5 6.411039f-5 -2.2393127f-5 -3.3563687f-5 -0.00018222499 4.9808426f-5 -4.0787756f-5 4.153631f-5 8.7279106f-5], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64.

julia
# Lux creates Float32 parameters. Convert them to a Float64 `ComponentArray` so they match
# the Float64 ODE state; this array is what the ODE problem and optimizer receive.
const params = ComponentArray{Float64}(ps)

# Wrap the network together with its (empty) state so it can be called as
# `nn_model(x, parameters)` without threading `st` through by hand.
const nn_model = StatefulLuxLayer(nn, st)
Lux.StatefulLuxLayer{true, Lux.Chain{@NamedTuple{layer_1::Lux.WrappedFunction{Base.Fix1{typeof(broadcast), typeof(cos)}}, layer_2::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_4::Lux.Dense{true, typeof(identity), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}}}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}, layer_4::@NamedTuple{}}}(Chain(), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()), nothing)

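Now we define a system of ODEs describing the motion of a point-like particle under Newtonian physics, with state $u[1] = \chi$, $u[2] = \phi$. The neural network $\mathcal{N}$ multiplicatively corrects the Newtonian rates:

$$\dot\chi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot\phi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_2(\chi)\right),$$

where $p$, $M$, and $e$ are constants.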

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = [χ, ϕ]`.

Both rates share the Newtonian factor `(1 + e cos χ)² / (M p^(3/2))`. That factor is
multiplied by `1 .+ nn_model([χ], nn_params)`: output 1 scales `χ̇` and output 2 scales
`ϕ̇`. The constants `(p, M, e)` are read from the global `ode_model_params`, and `t` is
unused. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` carries the network state `st`, which is an empty NamedTuple in this
    # example, so nothing needs to be threaded through here. In general, the network
    # state must be carried along explicitly.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Build the neural ODE with the untrained Float64 parameters and integrate it using the
# same fixed-step RK4 settings (step `dt`, saving at `tsteps`) as the reference data.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Plus polarization predicted by the untrained model.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained prediction. Each legend entry groups
# a line with its scatter markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
"""
    loss(θ)

Integrate the neural ODE `prob_nn` with network parameters `θ`, using the same
fixed-step RK4 settings as the reference data, and compare the predicted plus
polarization with the global `waveform`.

Returns `(sse, pred_waveform)`. `sse` is the sum of squared differences (not a mean),
and `pred_waveform` is the predicted waveform; the latter matches the third argument of
`callback`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    polarizations = compute_waveform(
        dt_data, Array(trajectory), mass_ratio, ode_model_params)
    pred_waveform = first(polarizations)
    sse = sum(abs2, waveform .- pred_waveform)
    return sse, pred_waveform
end
loss (generic function with 1 method)

Warm up the loss function by calling it once (the first call triggers compilation).

julia
loss(params)
(0.19864267104895475, [-0.024301811372489362, -0.02351452012844848, -0.022727228884407903, -0.02139688511538445, -0.019495536819405333, -0.016982600789421685, -0.013804773294192019, -0.0098930380059661, -0.005164929224601059, 0.0004758148614656758, 0.007129936216298874, 0.014881344621774489, 0.023748518717072422, 0.03355335369489713, 0.04363591239829475, 0.05217925789960898, 0.0547822006761726, 0.04215863348164113, 0.0006806040797772885, -0.06804636428782442, -0.11049675294048208, -0.07470574335830009, -0.005574624756931064, 0.03908467600707997, 0.05403529651411343, 0.05261888233971014, 0.044590671923407654, 0.03469122205281989, 0.024910229256385724, 0.015994002929522615, 0.00815757859059801, 0.0014043229124783886, -0.004341354581206901, -0.009175862036034382, -0.013193531743302851, -0.01647536239928888, -0.019090364184423246, -0.02109235822837772, -0.022522036397832096, -0.02340774787129988, -0.023766381457660014, -0.02360363719004896, -0.022915198099864333, -0.02168436092683319, -0.019883978379160328, -0.017474054425124858, -0.01440080980597618, -0.010595967478225419, -0.005975779177649145, -0.00044241752606841725, 0.006108455377838981, 0.013771177352305269, 0.022579182805187795, 0.032394157155558746, 0.0426331419069734, 0.051643800600613424, 0.055348135443862814, 0.044941021823385253, 0.006679323441789741, -0.06110250132155704, -0.10964485545264617, -0.08099830578514754, -0.012066680942600452, 0.0357158735233505, 0.053099146271229136, 0.052955865922199676, 0.04549384825604509, 0.03580470801667653, 0.02606312944526685, 0.01710749884370114, 0.009190408482811843, 0.002343809325161589, -0.003507184425076506, -0.008444140750909258, -0.012567998029348585, -0.01595123962268338, -0.018669527625882573, -0.020770317457504288, -0.02230001867404025, -0.023283287509648115, -0.023740147114047237, -0.02367571426555948, -0.023085276672685985, -0.02195446018647958, -0.020256358934579766, -0.017948230200202368, -0.014982384918975264, -0.011283340107600298, -0.006774660723411938, 
-0.0013488802412472962, 0.005094116287126373, 0.012664153348920537, 0.02140370939161869, 0.0312165168710654, 0.041585504990269605, 0.05101852832059182, 0.055743899435715984, 0.04743906354323166, 0.012396504978652106, -0.05394930223943055, -0.1079475392997044, -0.08685971552127861, -0.018753604469617782, 0.03204649542491366, 0.05196343531365937, 0.05318389097492307, 0.046340881970318326, 0.03689063831721967, 0.027207806166568496, 0.01821738254533854, 0.010232017791007478, 0.003288494320873175, -0.0026576523938012923, -0.00770232046861803, -0.01192493595773973, -0.015414272516710964, -0.018230119522723697, -0.0204333179727814, -0.022059535217463434, -0.02314222451060772, -0.02369618067529792, -0.023729715955485346, -0.02323888574611272, -0.022207187343386444, -0.020611028487103684, -0.01840678360377403, -0.015547189364507592, -0.011956804835656958, -0.0075591955542905105, -0.002244704122423525, 0.004089046919832806, 0.011559902053090982, 0.020225513695889062, 0.030020577297801974, 0.040498301097545875, 0.05030961093649584, 0.055975925281412264, 0.04966454613531651, 0.017807482356409356, -0.04664983234900535, -0.10543620499907917, -0.09221095283890503, -0.025599737790367033, 0.028072758259250652, 0.0506213200317278, 0.053294963327731766, 0.047127747904354964, 0.03794870957518663, 0.02833609967013095, 0.01933043919911632, 0.011274205929525837, 0.0042453893100039275, -0.001803129891538751, -0.006941840709215736, -0.011270723759420087, -0.0148594349358089, -0.0177784366996663, -0.02007635772208997, -0.021803549374501244, -0.022983218725913617, -0.023635031175185447, -0.023766423529078564, -0.02337426723258539, -0.022443456512921872, -0.02094655761837132, -0.01885086376279347, -0.016095032084506248, -0.012614149228614098, -0.00833068696485442, -0.003129938765558055, 0.0030943200518179777, 0.010461726663930504, 0.019042974930492593, 0.028809901610172352, 0.03937437968234371, 0.049521943113798976, 0.05605742689927047, 0.05162046388073492, 0.02290827373604027, 
-0.03928681614307046, -0.10213642349394113, -0.09698333713868743, -0.03256505029189903, 0.02380009865317201, 0.04905773933907662, 0.053285027755947165, 0.04784976182646066, 0.03897208682579719, 0.029456660083375923, 0.020435351981456774, 0.012321295990949237, 0.0052061443495273405, -0.0009298864632687454, -0.006174521342280191, -0.010601549111510001, -0.014291922480523139, -0.01730534564951829, -0.019706679536838252, -0.021530003259840558, -0.022807379112472746, -0.023555711430399855, -0.023785772709885262, -0.02349216777435523, -0.02266040757024243, -0.021269346295391494, -0.019274531446382727, -0.016627482425809238, -0.013256988614654849, -0.009086866938742258, -0.004001468224248271, 0.0021086125109075877, 0.00936768542945659, 0.01786121802178096, 0.027586515806346558, 0.03821619520135642, 0.04866229870175904, 0.055993987073191656, 0.053320466465224786, 0.02768109989333342, -0.03191441568183548, -0.09809851535325192, -0.10111666057501922, -0.03959156886953213, 0.019223038347228764, 0.04727046703551348, 0.05314370914395251, 0.04850225784077086, 0.039962434156996345, 0.03055762413289705, 0.021537909062536588, 0.013370575034054654, 0.0061761703777820235, -5.1199305176733116e-5, -0.00539299795807367, -0.009918584032970612, -0.013708088019119482, -0.016819882065498896, -0.019319055354470465, -0.021239681196065514, -0.02261412080804898, -0.02345948010044314, -0.02378730712791756, -0.023592452208444598, -0.022861061657088753, -0.021572221174310673, -0.01968245377189342, -0.01714329115748125, -0.013884056114055245, -0.009828421054927927, -0.005772785995800633])

Now let us define a callback function that records (and prints) the loss at each iteration.

julia
# History of loss values, one entry appended per callback invocation.
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Optimization callback. Appends the current loss `l` to `losses` and prints the iteration
number (the length of `losses`) together with the loss. `θ` and `pred_waveform` are
ignored. Always returns `false`, meaning optimization should not be halted.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf "Training %10s Iteration: %5d %10s Loss: %.10f\n" "" iteration "" l
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.

julia
# Differentiate the objective with Zygote.
adtype = Optimization.AutoZygote()

# Optimization.jl passes (parameters, hyperparameters); the hyperparameter slot is
# not needed here, so only the trainable parameters reach `loss`.
optf = Optimization.OptimizationFunction((x, _) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search and a small initial step, capped at 1000
# iterations; `callback` records and prints the loss each iteration.
res = Optimization.solve(optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback,
    maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.00012537259317444162; 0.00012816811795307047; 9.22439430722229e-5; 3.0486751711562943e-5; -2.2661844923232005e-5; 3.1150779250259925e-5; 3.519025631250707e-5; 4.273191370882255e-5; -2.248909595437594e-6; -4.7470952267686876e-5; -9.278141806119823e-5; 1.5630816051253696e-5; -3.5392942663713706e-5; 8.884549060867313e-6; -1.1557980542407164e-5; 7.04440317348282e-5; 4.500193244880667e-5; 6.961938925077733e-5; -0.00012714527838390962; -9.097427391675057e-5; -5.0633509090363894e-5; -5.26692820130563e-5; 8.388947753691855e-5; 0.0002985440951302215; 0.0003325793659312793; -1.7594622477211623e-5; 0.00011912560876201741; 9.642334771329893e-5; 8.043567504474917e-5; -4.296733459342975e-5; -0.00020805448002623152; -2.718079304026375e-5;;], bias = [2.8485316523204684e-16; 4.457065010862317e-16; 1.8863621531236157e-16; 1.730847975261434e-17; -5.163021705283189e-17; 5.774999843132338e-17; -1.9302695684418845e-17; 2.968654431222899e-17; 6.214176948776639e-19; -8.153019127146352e-17; -1.727852444826892e-16; 1.958353355789674e-17; -3.883101008649781e-17; 3.12118860947463e-17; -1.5356471381661898e-17; 1.3966000875126044e-16; 1.5199114721598273e-16; 1.1280563820209809e-16; -7.916616910487952e-17; -5.4991425667915867e-17; -4.4012733211068195e-17; -3.194459511521703e-17; 1.0680437025740113e-16; 4.801230177061413e-16; 1.9401606965645811e-16; 2.3911189193043943e-17; 2.751325025433783e-16; 1.434531368254366e-16; 2.9494425988063646e-16; 1.8555626587073053e-17; -2.799689715914708e-16; -5.921358249728335e-17;;]), layer_3 = (weight = [-6.0085992761794164e-5 4.5431028676777694e-5 -2.2865451417646365e-5 7.047187345229657e-5 9.843850811372188e-5 -9.165321144613805e-5 5.5361663082256685e-5 -3.279259154445946e-5 -0.00012260134436295643 -0.00012328027851985337 0.00013305263955039708 -6.420276959765067e-5 0.00020757194969697458 8.388083760288263e-5 8.929618733596703e-5 -0.000148291411922442 6.814245379424178e-5 
-0.00012204837886024089 3.660494648278679e-5 8.523778914600451e-5 -9.224071591965449e-5 5.963366883584682e-5 0.0002010059654748622 3.6652905163125244e-6 -0.00022540473192581036 -5.359520919959364e-5 3.815966219185103e-5 0.00012327753601489323 0.00010975041031265649 0.0001445716446068744 7.448698697036225e-5 2.6156336654784428e-5; 9.961958245550749e-5 -0.00016609097541366263 5.201110003692818e-5 -0.00010960929533755353 0.00011962972463518451 -4.5803541739563875e-5 4.1271262729447754e-5 -0.00010519225792083866 -0.00012022959593739274 -0.00015014260019231582 2.3379066206173104e-6 -2.3427352731208474e-5 0.000263037161188573 -4.341476811508516e-6 0.0001554320636060097 -1.791058160595218e-5 1.7723811892233947e-5 1.2044593189807235e-5 -0.00013169513725732002 -3.697652094704303e-5 0.0002769722789380699 -3.8564516886230674e-5 0.00011689977078642536 -0.00010035907210841087 -0.00010630283826323292 -1.7185970806143423e-5 0.00016364109002463107 -0.00012242716796752304 0.00021932796631883686 -0.00015334185330718734 5.104624984961434e-5 7.396940169567302e-5; -0.00015246487170821234 7.93243859602113e-5 -7.899038907503124e-5 1.1226177529771735e-5 -3.236874114967382e-5 1.1958301665310161e-5 0.00010595434717239306 7.20377326883505e-5 1.0285180660014479e-5 0.00013498881592518972 -1.6544049931824042e-5 -7.106087584410263e-5 -4.7028678882760416e-5 3.858658667874139e-5 3.383101350615342e-5 -1.2919967294276852e-5 0.00027908147306285205 0.00013169191945913293 7.179393587142766e-6 -3.337315071853199e-5 -9.308948508127309e-5 4.5354424154956654e-5 -5.41137594204469e-6 -9.30168783003329e-5 4.6212197711914864e-5 -0.00011257269513456558 0.00010082523361183686 2.5311376456967923e-5 0.0001805335457260801 -3.4209820360537745e-5 2.8001752734246396e-5 9.492077218006844e-5; -1.1066074784414986e-5 0.00016366324698601365 -0.00010080048095852179 8.95637959735603e-5 -1.2113271531519316e-5 6.639148623497644e-5 -5.4556855411022293e-5 -8.497976241352497e-5 -0.00012545315703329273 7.843823104911983e-5 
-0.00010753886899629643 -0.00011047311718294954 3.747505898053261e-5 -0.00012538598539259854 -3.2304375872003004e-5 -4.713828902759854e-5 -0.00015517331930760741 -3.943547280806389e-5 0.00016637658254757957 8.643676401397154e-5 1.2924733801526524e-5 -2.9993104271226855e-5 3.0191305177769194e-5 -2.9401983646723947e-5 8.416402406579614e-5 -6.473034309175418e-5 -0.00017162460871923406 1.9854139373539995e-5 7.546014422305776e-6 -2.522583863475354e-5 0.00010060044142002501 6.088825222711143e-7; 0.00022673726552145364 0.00013736775498816395 -5.162912212564142e-5 -0.00010295021284854696 -7.640492923109423e-5 0.00016376306656730628 5.76640776589627e-5 -1.380419244832745e-5 0.00010412568725861159 7.236986483895205e-5 0.00016683515049498043 0.00011339654510357846 0.0001278751041272769 7.376339989697094e-5 1.1679791797187528e-5 9.906949245304266e-6 -8.346874902508186e-5 0.00011640535002788872 7.749293935464405e-5 -7.96812774661642e-6 8.955422153212978e-5 -0.00018832960277813618 -1.052418347990745e-5 0.0001612320228476114 -2.0475439768467992e-5 6.482796369801689e-5 -1.1793128677730865e-5 4.226592126632403e-5 4.4760508800204705e-5 -0.00013091403171437715 3.865824318464156e-5 0.00018376035212987424; 3.3515117278784205e-5 -0.00011991960750101433 9.451592772427682e-6 -0.00019213620802995546 -7.054841521105589e-5 -6.0934765175010785e-5 -5.971163632897507e-6 0.00026953401279959226 6.223479172538099e-5 -0.00021492222262688512 -7.10875636702549e-5 -4.0443598523530094e-5 4.0230043991210576e-5 -6.449770703122876e-5 6.263571881779553e-5 -5.070138908715119e-5 6.521351819426207e-6 -2.364453834681455e-5 -3.880393600929865e-5 8.309100079975414e-6 -0.0002622516720709626 -2.8615052958425554e-6 -0.00018215435088287865 5.6536156220699e-5 -0.00012389715524982058 3.897934325711913e-5 -0.00016265229609935245 2.3240302793318838e-5 -1.877981853405389e-5 -5.550006139467131e-5 0.00011597297739567186 0.00012948229155360894; 5.4804867480162226e-5 -0.00020170829061211251 -2.2000604937460784e-5 
-5.5419309271064615e-5 -8.080013044204875e-5 5.241037893315497e-5 7.991985560829925e-6 -0.0002084198958499193 0.00015121219353491099 -2.3957329147823175e-6 0.00013244806637263158 -6.893355289034875e-5 -0.00012224572396630682 -0.00016352131474091077 -3.688369810956867e-5 4.1376814854768925e-5 -1.5256959110819493e-5 -8.352961182797048e-6 2.4454898314762864e-5 0.00019978596269825685 4.0824773398663994e-5 4.0826930720096735e-5 -1.893050370822066e-5 -0.00016185242834288545 -8.299838641211726e-5 0.00011544869362583058 -7.734558580235899e-5 0.00021519793977236657 1.058005823611373e-5 -0.00013971596758606132 -1.2807862682910504e-5 -2.3403151268985283e-5; 1.4452029823584821e-5 0.0001287150770040153 -0.0001227204085480596 3.724922258676977e-6 -0.00022572157119068247 3.831987836381544e-5 -8.61369445956576e-5 -2.8347202842467645e-6 5.2709812585795336e-5 -1.6369499817494449e-6 5.452920355298737e-6 -3.574804956226576e-5 6.179745319948248e-6 -0.00022681913484486738 -1.2225417016474331e-5 -0.00020841570600534172 -0.00011964210367661648 -4.405642554412926e-5 -9.771046112655529e-5 4.582393408770211e-6 1.279444389727286e-5 -4.899304407463678e-6 -0.00017194066109733712 -4.1798252062929653e-5 0.00013309603294703798 0.00016837037022558084 7.676901062586236e-5 0.00012481111829037064 -4.4476342885914614e-5 -2.5695313674097432e-5 7.433213234984073e-5 -0.0001399409413025053; -2.154561465479208e-5 5.150432858451801e-5 2.5178580617187455e-5 -2.600316460346964e-5 -7.918295323062319e-7 0.0001035747228884354 -0.00013327133259780468 5.638532657118124e-7 9.046250866914683e-5 4.7012570721115986e-5 1.939927475131551e-5 8.471099514237133e-5 -1.3095084027454213e-6 2.83680530260553e-5 0.00011058061878385508 7.492115051673118e-5 0.00010264242533549725 -0.00014887198980477286 6.095549193512308e-5 3.204657710693355e-5 -3.231479771829273e-5 0.0001082427080832597 2.9160717676401572e-5 3.665092677628575e-5 0.00017049305839394912 -9.929371648512335e-5 0.00013145237986416665 0.00010169874090887886 
0.0002199290867424662 7.933916291177372e-5 -1.497198703927397e-5 0.0001361179858208855; 8.059365104832854e-5 0.0001228933692369155 0.00010454490298122 -6.501604325187862e-5 0.000141806740841796 0.0001418416654383437 -0.0002973935475414948 0.00013654211984664584 -0.00011930155555767399 -0.00012400733207635232 -7.525181942978539e-6 4.0500403386042975e-5 -0.00012410324374962164 4.242029940776127e-5 1.2340163677957058e-5 0.0001694310635253427 -5.843854112234922e-6 6.281136804835601e-6 -5.129388186782471e-5 6.119617367008147e-6 4.1720861602307764e-5 -9.451222972295872e-5 0.00013690271630600052 1.0091829110471399e-5 -7.327930465865207e-5 -5.397081945141265e-5 -0.00012933579232361807 8.831583586703753e-5 0.00016792636638689786 0.0001024469606387095 5.902051311980672e-5 2.8703895125156373e-5; 5.7828014668062435e-5 -2.60179617373651e-5 -7.606664936362656e-5 1.2120574763044502e-6 -0.00016574518436870728 9.797532633340741e-5 4.9694047028178575e-5 -3.220887580143429e-5 0.00010600454928747168 -0.00014204996610830663 5.4250145137600776e-6 -0.00011658054923943446 1.8180916538128358e-6 -6.332745893930293e-5 -7.843026602247006e-5 5.8522261077752096e-5 0.0001197248152845805 -0.0001303632647807433 -7.440039686231475e-5 -4.257092356171358e-5 6.138665279534528e-5 -0.00010526425236209405 -6.127955167601509e-6 -8.158020997255677e-5 5.0315803072203526e-5 8.685005244727047e-5 0.00011510512789639969 -0.0001813763132861522 9.650962832777793e-5 5.461032063861237e-5 -4.7054728497898674e-7 0.00012405952524089075; 1.4065016719710331e-5 8.991325687764831e-5 -9.306216868190685e-5 -7.375267939260193e-5 5.041042121046475e-6 -9.158994414228912e-5 -2.17734546566148e-5 -7.948145555183328e-5 -6.280399299776067e-5 -0.0001293132086828358 7.249993307431563e-5 0.00013251919829435744 -5.722159091583444e-5 6.136234099194517e-5 0.0003050802474064762 -0.00017147016530707777 6.922814593799359e-5 -3.742863607094911e-5 9.504169377413287e-5 0.00015109126927945193 -3.008175605991953e-5 -1.765766189401153e-5 
-6.731952518572819e-6 2.0659364443525732e-5 8.472640131740178e-5 -6.104129767827599e-5 5.2361624901265474e-5 6.827168492981509e-5 8.238702814941805e-5 -8.111241706321396e-6 0.00012745820224896337 1.7188292466197696e-5; -0.00014151324393141352 1.69039711300772e-5 3.9172823392715145e-5 -9.506601162401925e-5 8.557893198767695e-5 -8.687890766568447e-5 -0.0001307844660795103 0.0001648599416631467 0.0002656163840702921 -0.00010897655118869527 0.00011940966106804448 -9.406627321995789e-5 0.0002679449524413055 0.00010374077720653263 1.1253408786708026e-5 1.9626829101165952e-5 -3.4735364521927324e-5 -0.00010944998319873386 -6.558945568657732e-5 3.618281590489085e-5 -7.864690377291189e-5 -0.0001130977263409447 -2.3388404940210138e-5 -9.458436505779568e-5 2.1302051039268613e-5 -0.00014693227529905727 5.664350356992931e-5 -0.00015642939359794916 -6.490769845813244e-5 -1.4509701201765998e-5 -2.6734279514928644e-5 0.00011253061421766903; 6.750832170874788e-5 0.0002885415603399553 -4.1597540357168646e-5 -0.00010251949362062795 5.472485344436075e-5 5.382876105651312e-5 2.5124602589505304e-5 0.00011710351890579846 4.4871013133535734e-5 -7.213214584760836e-6 -7.587455848769271e-5 -8.055212956230986e-6 -0.00017300013797429334 -5.184241953964234e-6 -3.6304123201151463e-6 -4.948279920052417e-6 0.00021220761181249334 -6.765604948294022e-6 5.107614259978729e-5 -0.0001899784832035842 -0.00020105672529970248 8.991597195782742e-5 2.9719007843970527e-5 6.108340303769166e-6 -3.439733273771186e-5 -2.921074638796914e-5 0.00010406169230046666 -0.00013104711189006448 -3.802975998913845e-5 -7.075212425623971e-5 -4.016650137563407e-5 5.180155193594338e-5; 0.00013664759677514104 0.00010493811492953162 -3.407413390786241e-5 6.296175484922381e-5 -0.00012016765180132713 -6.230560615049074e-5 -1.4326728376632903e-5 5.477269729627019e-5 -5.069285401814079e-5 -4.7713422134707386e-5 -0.00012611131609272988 3.548299861168112e-5 -0.00010236289603025964 6.208787510243329e-6 -6.809344125767623e-5 
-5.082785941166881e-5 6.016661205822338e-5 -2.670097495692402e-5 -9.524022101512482e-5 -8.334860866026575e-6 -6.898259238195836e-5 -9.961557447067679e-5 4.437150124604866e-5 -4.309699437547818e-5 -1.2357735102359759e-5 0.00011343956208466178 -0.00011146541731985714 -0.00016323220410239687 0.0001500194098861424 -2.6008532932754227e-6 7.210859399668425e-5 7.292948936256235e-5; 5.4176014560916435e-5 0.00011744958642910987 -3.478667535428151e-5 -5.6084880310383285e-5 -6.094642052966367e-5 2.2354637582164544e-5 3.572259625462533e-5 0.00011195817383514368 3.4168800954202684e-5 -5.740515738916051e-5 1.4965607039174356e-5 0.0001443829477978481 8.865422019544706e-5 8.225984817099216e-6 -0.00011324010669676741 -1.099775640222724e-5 0.00023044655422955375 0.00012356493421639546 -6.547771594679128e-5 0.00015889919628198096 7.80431527858184e-5 -0.000135975997221698 -3.1519210136775135e-5 -7.451666573437206e-6 -2.6229414328400987e-5 9.098649930484419e-5 5.64408215385943e-6 -6.556662814882866e-5 0.00020551258963341865 -8.154158971928412e-5 -0.00010877379649513898 4.33226322969027e-5; 5.674932778420212e-5 0.00010620458648845813 -6.820869587754284e-5 0.00012445457993381546 3.704191030537265e-7 -4.93174722876504e-5 -7.550701250947198e-5 -8.498743976160093e-5 0.00026575994867672134 -4.134299000069656e-5 -1.6207426808973925e-5 -1.7876962586153208e-5 2.0418534192874756e-5 0.0003985457967857729 -1.758897108883677e-5 8.950174539423292e-5 0.000107419162092999 -0.00011408505655053323 -0.00017899614452544332 -3.093308826718855e-5 -2.589809831333049e-5 -3.0686051316290996e-5 -8.249070945821102e-5 -1.7969470930281625e-5 -7.700098861128972e-6 -0.00011960137474742802 -3.761359970012339e-5 1.430642602572736e-5 -2.0042366329966623e-6 5.1030716137763556e-5 -4.404715785568365e-5 5.196665166950682e-5; 1.284781014547793e-5 0.00017369952606387451 4.469257023303881e-5 -5.3188937593920396e-5 4.859618697667711e-5 -3.120842971253984e-5 4.73981678136916e-5 0.00022708548262601763 -3.0313374148651005e-5 
0.00014156758756728833 -3.1706094299423445e-5 0.00010999228893200186 -0.00011376952111601708 8.176869655147468e-5 4.759887874245908e-5 9.928065323622538e-5 -5.389896919365952e-5 1.7197110215808537e-5 -8.498730981372245e-5 6.123954923552484e-6 8.030385710883835e-5 4.3240500096026045e-5 -6.382197111959266e-5 -2.445100429843555e-6 -0.0001967352649023484 5.229867078599763e-5 0.00019560667774059184 0.00022473601770472724 -2.565887130329009e-5 -0.00015548269525749567 3.9399525709424947e-5 -7.326630234598925e-5; -3.634207337622771e-5 2.088709678447365e-7 -0.00016147811425495253 -0.00025208127795365106 -7.489111908823706e-5 6.558435392919536e-5 -4.972669944861328e-5 8.467962057082484e-6 0.00010888971859942597 -9.666423307378063e-5 -0.00012636582710425648 3.0168220969540734e-5 -4.29356007359026e-5 5.179102640370323e-5 0.0002517321915959889 3.660860943030719e-5 -2.1390588021560267e-5 -1.924945923262139e-5 1.595934895281036e-5 -0.00011006901354872947 -1.6827933227132893e-5 -8.286983962049854e-5 -0.00011720874340850424 6.633130373798794e-5 -0.00012876278947670318 0.00021092761721349767 -0.00021815262903569717 6.272932636906208e-5 -8.95780615379998e-5 1.7038169740160985e-5 -0.00011599688901206673 2.893770893321318e-6; -3.4016107333911806e-5 -2.590082071520794e-5 2.6170383038960364e-5 4.758694781657143e-5 -0.00010826881184158074 0.00010158757363996303 8.190518424697982e-5 -0.00011583775044101472 -1.701583054938038e-5 3.3882045003965505e-5 6.168474060076474e-6 -5.2273489514218324e-5 0.00015272413880183192 8.207147625825216e-5 -7.883610082472827e-5 9.861900476124841e-5 -5.069482862864803e-5 -9.610573901710328e-5 0.00021672148291571595 0.00014913592754482014 -1.6586438091790336e-5 4.274063444144955e-6 8.36797612052478e-5 -0.00015203311481429873 -0.0002938036209755671 -5.838911742585867e-5 -6.730683666802384e-5 -4.888135440097486e-5 -2.5223665529948377e-5 -4.780579324261465e-5 0.0002204584729540173 7.925609539114709e-5; -0.0001666659814249432 -0.00020687437839244217 
-0.00023723713869134422 0.0001588136942414547 -2.9485853275636284e-5 -0.00015748073977325303 0.00014253298238012807 -7.03679911852815e-5 1.7347770579853817e-5 3.964055164810351e-5 -1.285865121775748e-5 8.755178666300728e-5 -5.382721477095168e-5 -7.148206398729614e-5 -4.1509087651738825e-5 -6.530225121555199e-5 -6.34270332115699e-5 1.4318550227972904e-5 1.541009575573949e-5 -1.576073239579377e-5 1.1156742828957806e-5 -4.399843285814668e-5 4.661423687917339e-5 -0.00012976457995251043 -0.00012869574177894156 -7.569445054419667e-5 6.994701058699554e-5 -0.00011088815566057635 -0.00013837734925932206 0.000152682099240832 -0.00011276188936913964 -9.74760774001503e-5; -8.169449997522386e-6 -9.629866077813023e-5 -2.749951751394028e-5 7.579853876013334e-5 4.243915768890151e-6 0.00020491066128233476 5.4838916853619867e-5 3.256890953270706e-5 -5.205237654518271e-5 0.00013074924093889044 -3.785414197846617e-5 -8.708927022308057e-5 2.2097060296069954e-5 -6.011789925178451e-5 2.164784449196015e-5 0.00010028034624628578 0.00015006639908751455 0.00010989852920392387 1.423390743656392e-5 7.800437040862053e-9 -0.00010460255840808197 -1.3332388696693808e-6 7.091025268158397e-6 -2.4664082288695568e-5 -7.451951502376227e-5 -0.00018406108689742105 4.8594704581037356e-5 -6.732613952845666e-5 -0.00012263118084302385 3.8065917338095475e-5 0.00028555493075616754 -8.676597032245215e-5; -3.0979173546862674e-5 -0.00012354025742176058 -4.063529334750235e-5 -1.6099108992119918e-5 -0.00015040248175288897 -1.9053089575817472e-5 7.366636643048975e-5 -9.54549082359052e-5 9.39422847898921e-5 -3.653791057665217e-5 6.429726132985931e-5 4.0422621592670883e-5 -0.0001585106779508978 -0.00022510555348496307 -6.388038812808113e-5 7.673973092672146e-5 0.00012352947893880917 0.00011814356045427461 0.00015544475396958108 1.7925548005525695e-5 -0.00010915053784053153 4.6653835366970586e-5 -9.004259985310258e-5 0.00017885698113454765 0.00014476075236084576 2.5467869333015973e-5 9.889759755090702e-5 
-7.679572948624575e-5 7.926910845192838e-5 3.952778084868587e-6 -8.980698979363379e-5 -0.00012060401562272913; 8.655297916215693e-6 2.8646542614973796e-5 0.0001862533378705357 -6.890015474641808e-5 4.1659203544211436e-5 -7.914236020883659e-5 6.31824065756041e-5 0.00012346070316266136 -5.0884618121246844e-5 -4.4475070625824386e-5 7.084700946574775e-5 -0.00014272522939723085 -3.797834988222578e-5 -0.00018405239215940335 -1.555582047560421e-6 0.00019219300766121406 0.00014830042955577238 -1.2193546626124993e-5 5.6266037700647244e-5 0.00018817896370887504 1.5528454878901565e-5 0.000188125179830192 -0.00014725749615468312 1.0739292247971706e-5 -2.384707049648098e-5 -2.6791897925479415e-5 -2.3869222149419463e-5 -8.28103159613003e-5 -0.00017648935436159592 2.8757035307305496e-5 -0.00018526260218937847 5.360957466177531e-5; -0.00019469413339128505 8.790272156739804e-5 0.00026421970731145094 -2.514135442059985e-5 0.00022039418443138214 -2.7093706407750555e-6 0.00014758182849301452 7.94050506342617e-5 7.655396664314419e-5 -0.0001274819801651459 -7.369670449724104e-5 5.990261022499275e-5 -8.967040732703674e-5 -0.0001301661682405092 3.249568951697534e-5 8.002509319022669e-5 8.987826047497287e-5 2.0905338046871004e-6 1.942605602062975e-5 3.702247747834162e-5 -7.957729351209066e-5 1.9428400697972955e-5 -8.450792414340044e-6 -7.539952569915652e-6 -3.730277196473741e-5 -1.619269416868244e-5 -0.00019354898497430374 -5.381392431762695e-5 0.0001259201385223714 -0.00014563088145540612 -1.2257907222641866e-5 -5.173725687517348e-5; 0.0001832658305099671 0.00018730315752667 4.873667750848107e-5 0.00012693378050221697 6.723699299960085e-5 -7.816474299257582e-6 -6.152219464661023e-5 1.063007095558795e-5 0.00012542457227005373 -2.3193387098651168e-6 -1.931223812812836e-5 -0.00021649697771885712 -0.00016830710263115756 0.00010279347736264319 9.219104651470403e-5 -1.5999423685799495e-5 -3.466831220089023e-5 1.1360453408768865e-5 -7.759486505189327e-5 -6.453782441918129e-5 
-3.365534974386546e-5 0.0001279803251416084 -1.9250257883190033e-5 0.00012409973680372477 0.00014309039363034394 5.008457828584965e-6 -7.778532779435736e-5 2.7513466372901684e-6 9.842152958068894e-6 0.00011546915595295366 7.70078473758055e-5 -2.6808521672552765e-6; -5.583514553945785e-5 1.8644101365498843e-5 0.00016949703444043342 -7.916281152501075e-5 -3.724749309111667e-5 3.6754954854414504e-5 -8.429108835042777e-5 0.00011122951422367817 0.00026308692053208805 -5.322918128439611e-5 6.375950783040534e-5 0.00010887232951136974 3.636731366060349e-5 -8.48152081670261e-6 -4.805220283668747e-5 -5.766551817868457e-5 7.361652409673297e-5 0.00018777693207158166 8.252097926815672e-5 -0.00016758785182827902 -3.8437574184345406e-5 -0.00010196703784385183 9.225240523010748e-5 2.7461683421963115e-5 -5.4995794707122034e-5 8.089133938557201e-5 6.931388656157821e-5 -6.0588121195173583e-5 -0.00021018528037456506 1.6720420026840758e-5 -5.0317459462546636e-5 4.783665303040001e-5; -0.00012093203871197015 0.0002083579104378551 0.00011558192751683688 -4.850083745121253e-5 -3.338232885151432e-5 1.7867812189722978e-5 1.9113963631314983e-5 3.3847399450240974e-5 -0.0001269783740412847 -9.350070352818937e-5 0.00011695233230559557 5.2107244542201004e-5 -0.00018209100265674494 -3.233601340478582e-5 -3.843244370226109e-5 -0.00015595096077456973 -0.00011758381764091689 -0.00012012453110808415 3.760739970451312e-5 -0.00021085689398612454 -6.417601137807687e-5 -3.4419705784312284e-5 6.36689812977453e-5 6.15962355965145e-5 -8.084705653744364e-5 -8.927380143584962e-5 1.118249498679696e-5 -5.6699877618006513e-5 4.991643187670264e-5 7.197805213356238e-5 -8.827734504342216e-6 0.0002576674971947049; -5.960567038988192e-5 -2.9042786695808063e-5 0.0001443254771629145 3.0775198183138366e-6 9.514622463479766e-5 -7.13768411065109e-6 -1.335591121010382e-5 0.00015230782839803104 -0.00019427353588530652 -6.9222907541674295e-6 -0.00014516096750872383 -1.1385738321271292e-5 0.0001063070124700489 
-8.322771798252916e-5 -7.283483471387856e-5 2.692554182880988e-5 8.973799078825934e-5 -0.0001936806326512437 0.00013873511154537574 -2.8288033422591803e-5 -0.0001341416606692779 -4.293712815984108e-5 -4.526847230868526e-5 -4.307723399969137e-5 6.452414920650248e-5 9.999964572623316e-6 0.00017860146574543413 -3.4796874092586435e-5 9.300475751383139e-5 8.425858352433968e-5 8.63781427369123e-5 9.450108002506316e-5; -4.735592169914714e-5 8.181646757997327e-5 3.0223436122223345e-5 -0.0001770129552390052 -3.3848768500574204e-5 2.081445880067968e-6 -0.00016612519771322563 0.00014089484961707252 0.00010119447456210772 -6.112499481994504e-5 -5.549485881808337e-5 -5.5344104614304934e-5 2.9081139880619506e-5 -7.480116851391343e-5 -0.00011722456343002684 -7.121225876496671e-5 -0.00018338142631571614 -0.00020002699834756134 -8.744332676748087e-5 0.00011240829685241759 0.00016320973140657134 -4.2704209183536266e-5 0.00019767725229733341 2.905028072547913e-5 0.00011074273556718773 -9.00215477961679e-5 2.0465023633492998e-5 3.8153042121603574e-5 -8.937731995708651e-5 -4.4982024112212305e-5 -4.224272793383602e-5 -5.6322066077226814e-5; -8.498145360612557e-5 0.0001350233647961049 -6.780301226648499e-5 5.6290824081555855e-5 -2.32517289853021e-5 -0.00010345364024509407 -0.00012106458633319214 3.3407657260580516e-5 -0.00011907381884633262 0.00010782570426705516 0.0002483061976952386 -1.718680722227492e-5 6.622970942295416e-5 1.6380219326278397e-5 -0.00014495487301722804 3.99806696328813e-6 4.74401236849767e-5 0.0001428107349190781 -1.638359243337793e-5 -4.897477664166118e-5 4.705140290873734e-7 -0.00013702243485134823 0.00019797139940670746 9.642951737508157e-5 -0.00016748727683559835 0.00011347572546107054 -5.905863543852426e-5 -8.708486746876524e-5 -3.054005390844233e-5 -2.632987378568616e-5 8.187381864877137e-5 0.00011939267971207665; 8.503846110241086e-5 4.771599268704228e-5 -0.00017844623581342558 0.00016355844926718713 0.00011235786295427112 1.7527698937813713e-5 
-0.00020054922716528098 5.805174877449e-5 -0.00010292597217786936 -5.145547865792876e-6 -6.278628207812252e-5 4.221390626324455e-5 -8.897059351253922e-6 0.0001569741587357235 -5.1149941291183915e-5 7.655852167798162e-5 4.6796286178762644e-5 2.088210093487015e-5 -9.916675408905787e-5 -4.655106314177778e-5 0.00014374677338335696 0.0001762549953039366 -7.486425530234465e-5 -0.00010843106357219679 5.2268788558870384e-5 -0.00018888460109311975 -6.002221126411461e-5 -0.00011650058805545981 -5.735455780242664e-5 0.00015992044135626466 -2.0393451538732256e-6 0.00020783506696742126], bias = [3.067081670343341e-9; 1.077991352344351e-9; 3.242102559632206e-9; -1.910756794892824e-9; 6.726523689169542e-9; -4.5355097663615086e-9; -2.069052107360952e-9; -3.7727332021712865e-9; 6.966874492846697e-9; 2.5761006081278366e-9; -1.1545682999071793e-9; 2.5460651155831627e-9; -1.296199932748172e-9; 1.606989937121807e-10; -1.930346204166116e-9; 3.545170611251893e-9; 1.1148287814046742e-9; 4.009640573554416e-9; -4.21301108379211e-9; 4.788077514530398e-10; -6.505189467440968e-9; 1.3012460812064987e-9; -5.361777486472568e-10; 5.663430308648547e-10; 4.660946734265569e-10; 3.6958406374817986e-9; 1.8555882082550431e-9; -2.5381582260796524e-9; 1.1244309894392014e-9; -3.1673628472358864e-9; 1.0764913570603444e-9; 1.0193528247116395e-9;;]), layer_4 = (weight = [-0.0004890177519704853 -0.0007076403653472109 -0.0007211410214461274 -0.0006991902528441092 -0.0008861143008196651 -0.0006345514947296007 -0.0007220250631264328 -0.0008069504040575455 -0.0005691055478941535 -0.00043994704633209896 -0.0008205050128452781 -0.0006958338198484535 -0.0008500602599675031 -0.0008717260065601919 -0.0008138155913083992 -0.0006921712428238562 -0.0007188541226209268 -0.0004899240400506721 -0.0006062540261484343 -0.0006317998388074185 -0.000534507339738065 -0.0008172835988397439 -0.0004958206376910766 -0.0008500348807436666 -0.0006299140490201394 -0.0008758227976199528 -0.0008339446416863287 -0.0006077338730261422 
-0.0006103338643463324 -0.0007623898840051351 -0.0005959962713431696 -0.0006694464026883215; 0.00020455299704231456 0.00014090093269610458 0.000323015577893387 0.00025414974675671555 0.00023432921972551843 0.00036114332823051527 0.00021671201864469005 0.00023231695042985617 0.00018795433265487662 0.00013299254890080146 0.00023169367956826323 0.00012308374099900956 0.00022360793578335016 0.0003156333289768353 0.000279893950134309 0.0003279980649032129 0.00019910610847046398 0.00016896560433531305 0.00019741280278746086 0.00030224634461944924 0.0002656012509362382 0.0002594282986570807 0.00031349106042084584 0.00025291699223991247 0.0002843808201045841 0.00019787723974853483 0.0001867067277814646 3.8045412672215016e-5 0.000270078851351094 0.00017948262959211245 0.00026180673558194226 0.0003075495318982846], bias = [-0.0007413768188562973; 0.00022027043037844922;;]))

Visualizing the Results

Let us now plot the loss over time.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # Plot the recorded losses against their iteration index, both as a
    # line and as individual markers.
    lines!(ax, eachindex(losses), losses; linewidth=4, alpha=0.75)
    scatter!(ax, eachindex(losses), losses;
        marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally, let us visualize the results.

julia
# Re-integrate the ODE with the optimized parameters using the same fixed-step
# RK4 settings, then convert the trajectory into a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4();
    u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Neural-network waveform before training
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Neural-network waveform after training
    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Each legend entry combines the line and marker plots of one series.
    axislegend(ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report CUDA details only when LuxCUDA is loaded and a usable device exists.
if @isdefined(LuxCUDA)
    if CUDA.functional()
        println()
        CUDA.versioninfo()
    end
end

# Same for AMDGPU: require LuxAMDGPU to be loaded and functional.
if @isdefined(LuxAMDGPU)
    if LuxAMDGPU.functional()
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.4, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.3.4
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.735 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.