Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl

Package Imports

julia
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector r = r₁ − r₂ and use Newtonian formulas to get r₁, r₂ (e.g. Theoretical Mechanics of Particles and Continua 4.3)

julia
"""
    one2two(path, m₁, m₂)

Split a one-body relative path into the paths of the two individual bodies.

Each returned path is `path` scaled by the *other* body's fraction of the total mass
`m₁ + m₂`; the second one carries the opposite sign. Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    first_body = (m₂ / total_mass) .* path
    second_body = (-m₁ / total_mass) .* path
    return first_body, second_body
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables: (χ(t), ϕ(t)) ↦ (x(t), y(t))

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in orbital-element coordinates into Cartesian orbit coordinates.

# Arguments
- `soln`: matrix with one column per saved time step. Row 1 is `χ` and row 2 is `ϕ`.
  When four rows are present, rows 3 and 4 hold `p` and `e` at every time step.
- `model_params`: `(p, M, e)`. Required (length 3) when `soln` has two rows; not read
  when `soln` has four rows. Only `p` and `e` are used.

# Returns
A `2 × N` matrix whose rows are `x = r cos(ϕ)` and `y = r sin(ϕ)`, with
`r = p / (1 + e cos(χ))`.

# Throws
`AssertionError` if `soln` does not have 2 or 4 rows, or if `model_params` is missing or
not of length 3 in the two-row case.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Constant orbital elements come from the caller-supplied parameters.
        @assert model_params !== nothing && length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        # Time-varying orbital elements are part of the solution itself.
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    # Stack the coordinates as rows so that each column is one time step.
    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` with spacing
`dt`.

Interior samples use the centered difference `(v[i + 1] - v[i - 1]) / 2`; the first and
last samples use three-point one-sided stencils. `v` must have at least three entries.
Returns a vector of the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -1.5 * v[1] + 2 * v[2] - 0.5 * v[3]
    right_edge = 1.5 * v[end] - 2 * v[end - 1] + 0.5 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) ./ 2
    return vcat(left_edge, interior, right_edge) / dt
end
d_dt (generic function with 1 method)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` with spacing
`dt`.

Interior samples use the three-point stencil `v[i - 1] - 2v[i] + v[i + 1]`; the first and
last samples use four-point one-sided stencils, so `v` must have at least four entries.
Returns a vector of the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    interior = v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along an orbit.

`orbit` holds `x` in row 1 and `y` in row 2, one column per time step. `component` is an
index pair: `(1, 1)` and `(2, 2)` select the diagonal entries (with one third of
`x² + y²` subtracted); any other pair selects the off-diagonal entry `x * y`. The result
is scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    row, col = component[1], component[2]

    # One third of the trace, removed from the diagonal entries only.
    third_of_trace = (xs .^ 2 .+ ys .^ 2) ./ 3

    entry = if row == 1 && col == 1
        xs .^ 2 .- third_of_trace
    elseif row == 2 && col == 2
        ys .^ 2 .- third_of_trace
    else
        xs .* ys
    end

    return mass .* entry
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the selected moment-tensor component.

The component is built with `orbit2tensor(orbit, component, mass)` and differentiated
with `d2_dt2` using the sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate `h_22_quadrupole_components` for the `(1, 1)`, `(1, 2)` and `(2, 2)` tensor
components and return them as the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    xx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    xy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    yy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    return xx, xy, yy
end

"""
    h_22_strain_one_body(dt, orbit)

Compute the two strain series for a single body moving along `orbit`.

`dt` is the sample spacing; its type `T` is used for the numeric constants. With
`(h11, h12, h22) = h_22_quadrupole(dt, orbit)` and `c = sqrt(π / 5)`, the function
returns the tuple `(c * (h11 - h22), -c * 2 * h12)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalization is the square root of π / 5 (not π / 5 itself).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components of two bodies.

`h_22_quadrupole` is evaluated for each `(orbit, mass)` pair and the results are added
component by component. Returns the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    # Both are (h11, h12, h22) tuples; add matching entries.
    return map(+, first_body, second_body)
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Compute the (2,2)-mode strain from the orbits of BH 1 (mass `mass1`) and BH 2 (mass
`mass2`).

`dt` is the sample spacing; its type `T` is used for the numeric constants. With
`(h11, h12, h22) = h_22_quadrupole_two_body(...)` and `c = sqrt(π / 5)`, the function
returns the tuple `(c * (h11 - h22), -c * 2 * h12)`.

# Throws
`AssertionError` if `mass1 + mass2` differs from 1 by `1e-12` or more.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalization is the square root of π / 5 (not π / 5 itself).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Turn an ODE solution into a pair of strain series.

# Arguments
- `dt`: sample spacing of `soln`; its type `T` is used for numeric constants.
- `soln`: solution matrix accepted by `soln2orbit`.
- `mass_ratio`: value in `[0, 1]`. `0` selects the one-body formula; a positive value
  splits the orbit into two bodies with masses `m₁ = mass_ratio / (1 + mass_ratio)` and
  `m₂ = 1 / (1 + mass_ratio)`.
- `model_params`: forwarded to `soln2orbit`.

# Returns
The tuple returned by `h_22_strain_one_body` or `h_22_strain_two_body`.

# Throws
`AssertionError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        # Pass the locals computed above; the two-body branch needs both orbits and masses.
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

u[1] = χ, u[2] = ϕ

where p, M, and e are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place ODE right-hand side for the state `u = [χ, ϕ]`.

The second argument unpacks to the constants `p`, `M` and `e`. `t` is accepted for the
ODE-solver calling convention and is not used. Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, params, t)
    p, M, e = params
    χ, _ = u
    cosχ = cos(χ)

    # Factors shared by both derivatives.
    shared_numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    shared_denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared_numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * shared_denom)
    dϕ = shared_numer / (M * (p^(3 / 2)) * shared_denom)

    return [dχ, dϕ]
end

# Problem setup shared by the true-model and neural-network simulations.
mass_ratio = 0.0         # test particle (selects the one-body branch of compute_waveform)
u0 = Float64[π, 0.0]     # initial conditions [χ, ϕ]
datasize = 250           # number of saved time points
# NOTE(review): `tspan` uses Float32 literals while `u0` and `dt` are Float64, so
# `tsteps` and `dt_data` derive from Float32 values — confirm this mix is intended.
tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing between saved points (used for finite differences)
dt = 100.0               # fixed integrator step passed to `solve`
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the true model with fixed-step RK4, saving the state at every entry of
# `tsteps`, and keep the result as a plain matrix (rows: χ and ϕ; columns: time steps).
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a 2-tuple; only its first element is used as training data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers on top.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    # NOTE(review): `markershape` and `markeralpha` look like Plots.jl-style keywords —
    # confirm the installed CairoMakie version accepts them.
    s = scatter!(ax, tsteps, waveform; markershape=:circle,
        markersize=12, markeralpha=0.25, alpha=0.5)

    # Group the line and scatter handles under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (χ, the first component of the state, as passed in `ODE_model` below) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally not recommended, but in case you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl.

julia
# Network used inside `ODE_model`: an element-wise `cos` applied to the input, followed
# by Dense layers 1 => 32 => 32 => 2. The two hidden layers use `cos` activations; the
# last layer has no activation. All weights are drawn from a truncated normal with
# std = 1e-4.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# `ps` holds the initial parameters and `st` the layer states.
# NOTE(review): `Xoshiro()` is constructed without an explicit seed, so the initial
# weights presumably differ between runs — confirm whether reproducibility is wanted.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00022478317; -0.00014757249; -0.00015495336; 9.269136f-5; -6.675013f-5; 9.531766f-5; 0.00012224248; 0.00014550178; 0.00010277838; -7.804003f-5; 8.9768226f-5; 4.1223757f-5; 3.431828f-5; 1.0803946f-5; -9.7221935f-5; 1.8967345f-5; -4.4868255f-5; 0.00013739581; 3.9240826f-6; 4.6769565f-5; 9.524104f-5; -1.8312436f-5; 9.288281f-6; 9.24653f-5; -0.00028629895; -0.00023564558; 0.000188849; -5.3668577f-6; 0.00014410698; 8.8745415f-5; -7.577554f-5; -2.529135f-5;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.00010002282 -0.00012869238 -4.391901f-5 -2.6611267f-5 -0.00011390923 0.00013940266 0.00012123657 9.8293585f-6 3.6267407f-5 -0.000306966 6.470216f-5 -4.4106066f-5 6.623466f-6 0.0002559341 -2.2499293f-5 0.00021290561 4.445685f-6 -0.00010960925 -0.00013266462 4.9833645f-5 6.958819f-5 -0.00011970948 -0.00012769835 -2.965727f-7 0.00017064325 0.000106795866 -6.1012044f-5 -6.729497f-5 -4.5553046f-5 -4.6125417f-5 9.517618f-5 6.0932856f-5; 3.0309782f-5 -5.7506786f-5 1.3833771f-5 -0.000110737936 0.00025416425 -1.0774371f-5 0.0002625921 0.00014973998 -0.00010504948 -6.450833f-5 -0.0001450622 8.426773f-5 -0.00013259467 8.929161f-5 -7.3066614f-5 9.645316f-6 -3.98393f-5 -0.00017194082 -1.0333052f-5 9.100387f-6 5.7323967f-5 7.715402f-5 -0.0002186616 0.00014116321 1.3228973f-5 2.704261f-9 -0.000121639896 -5.3741146f-6 -2.3053239f-5 1.1528426f-5 4.7809808f-5 -0.00026921515; 0.0001410839 -0.000108244974 -0.00027407528 2.3870341f-5 -0.00013748102 0.00013524128 3.3717824f-5 0.0001079182 2.7724496f-5 -9.469286f-5 9.0534675f-5 1.1350232f-5 9.25568f-5 2.355729f-5 7.9208476f-5 6.736407f-5 2.1860567f-5 0.00014652056 -4.1445852f-5 -0.000107613945 9.659725f-5 -4.347962f-5 -7.347811f-5 -2.7132908f-5 3.9148246f-5 2.1844696f-6 6.398158f-5 -3.3282402f-5 -5.81314f-5 -0.00011577405 
-7.932504f-5 -2.1201004f-5; -0.00015654835 5.1739924f-5 4.619325f-6 -0.0001982151 2.9087307f-6 0.0001087846 -0.000116469455 -5.229869f-6 0.00011958847 9.610897f-5 0.00014536902 1.3524706f-6 0.00010315249 -0.00024865972 -6.195806f-5 2.5924615f-5 -0.0001390205 0.00013308492 -9.755733f-5 1.5774483f-5 1.086926f-5 -5.365305f-5 -1.6457343f-5 -0.00020141757 -5.1551055f-5 0.000113519316 -4.446919f-5 0.00014995283 0.00019539411 -3.4264867f-5 4.664481f-5 1.9431085f-5; -5.7666086f-5 6.107596f-6 -0.00026658052 8.630304f-5 -0.00013964107 -8.444826f-6 1.5843076f-5 -3.5323505f-5 -5.1554434f-5 9.5875956f-5 -3.8919745f-5 -2.6361098f-5 9.6110554f-5 -8.288298f-6 -5.524749f-5 -9.192253f-5 7.8501005f-5 -3.9269464f-5 4.2311673f-5 -0.0001074762 -0.00022981285 5.9673854f-5 -9.5015545f-5 -0.00017922447 7.7534816f-5 0.0001472331 6.512855f-5 -6.8470667f-6 0.00011840324 -7.874325f-5 0.00014587432 -6.405526f-5; 0.0002381581 4.875223f-5 0.00011359465 0.00019395907 -1.0349455f-5 -0.0001586724 4.323175f-5 0.00018281459 -7.063995f-5 -0.00011452466 -2.0377405f-5 -1.4138527f-6 8.062809f-5 -1.1642779f-6 -6.0468244f-7 -6.649398f-5 7.4638054f-5 6.725995f-5 -5.7677167f-5 -8.465831f-5 0.00011732817 -1.0294856f-5 -2.8004265f-6 -1.1277304f-5 0.00012695315 -0.00014742103 -1.1215614f-5 5.7509165f-5 1.7984039f-5 0.00015558422 9.291793f-6 -4.9424525f-5; 5.5550485f-5 -8.0176695f-5 9.890225f-5 -8.8013556f-5 3.0521358f-5 2.0218697f-5 3.4865136f-6 0.00014468907 -1.0977402f-5 -2.4844996f-5 0.00018342484 -0.0002889353 0.00013465117 -2.5303507f-5 -5.0206785f-5 -8.998733f-6 3.6627713f-5 0.00012804789 9.785308f-5 0.00016550854 -3.9509796f-5 -0.00011322541 -9.811461f-5 -9.468847f-5 -3.065937f-5 5.61725f-5 2.4350269f-5 3.1761843f-5 0.00010984898 6.442621f-5 -3.726842f-5 5.546712f-6; 4.616399f-5 5.1614286f-5 2.0060299f-5 -7.742845f-5 -1.3718129f-5 -9.95445f-5 -1.6840617f-5 -3.916906f-6 0.00011855612 4.385685f-6 0.00015394311 8.84942f-5 3.6714562f-5 -2.0302468f-5 -0.00012429235 -7.1781587f-6 5.1335148f-5 -0.00010800273 
-5.9597372f-5 5.709248f-5 -5.0078077f-5 0.00017321442 -6.67177f-5 -0.00013019399 5.617815f-5 5.3015203f-5 -0.00013182378 0.00013977538 -0.00016327389 0.00012336382 1.6562151f-5 0.0001686086; -8.3522435f-5 -9.0048765f-5 -2.4646577f-5 -0.00017855706 -0.0001145417 0.00013401051 -4.1139516f-5 0.0001396352 8.5494394f-5 -7.662247f-5 -1.9241153f-5 0.00017821316 0.00013309832 -8.9227055f-5 9.732358f-5 -0.00014777435 7.2495513f-6 -0.00013150321 4.5573033f-5 -0.000103681006 0.00029668573 6.1201995f-6 2.8813964f-5 3.4781717f-6 7.403685f-5 -1.3061934f-5 -0.00011647075 5.705374f-5 5.1346706f-5 -2.8001743f-5 -8.166242f-5 3.1602438f-5; 0.00011473602 4.581222f-5 0.00010036754 8.079245f-5 7.533182f-6 -8.798934f-6 2.4572842f-5 -2.9947783f-5 8.345943f-5 4.1200256f-5 -0.00015817293 -7.8054094f-5 0.00011209196 -3.29736f-5 0.00011541845 -8.3498504f-5 0.0001193326 -9.162597f-5 2.6011946f-6 7.128562f-5 0.00016970208 -1.5366186f-6 2.173176f-5 -0.00016322733 3.4756304f-5 5.0335122f-5 -0.00011387682 -7.8517696f-5 7.828914f-5 -0.00021620384 1.21734465f-5 6.658776f-5; 3.8045855f-5 -0.00026755536 -2.8755036f-5 7.589805f-5 -1.7441707f-5 -4.0950257f-5 -0.00015395563 0.00012464295 6.711842f-5 -0.00012230499 -0.00024676838 -1.35775035f-5 -2.3424684f-5 -0.00014865467 -1.4795824f-5 -2.6561087f-5 -7.834509f-5 3.0621803f-5 2.2768887f-5 -0.00010052241 -0.00016816826 6.15117f-5 6.213778f-5 8.141208f-5 9.0760266f-5 -9.788292f-5 8.187855f-6 -8.3786246f-5 -2.5942833f-5 -5.399234f-5 1.9729772f-5 1.9757257f-5; 2.9481032f-5 0.00019273185 -0.00014670652 -7.502017f-5 -4.333033f-5 -3.0126106f-5 7.749218f-5 0.00023461548 -3.057164f-5 3.0403324f-5 0.00013767186 8.029191f-5 -3.9742124f-5 1.7609842f-5 -0.00014911633 9.8781544f-5 5.7636906f-5 8.625198f-5 0.00012202725 -0.00013073358 -9.109478f-6 8.085929f-5 0.0001234032 -9.226178f-5 2.5111412f-5 9.459688f-5 -8.404779f-5 7.1376853f-6 -2.8122198f-5 -0.00012714222 0.00014290454 7.2704286f-5; 0.00012738499 -3.8577535f-5 -1.030564f-5 -0.000106949636 0.0001339977 
3.9823997f-5 4.018373f-5 -5.6419758f-5 0.00010091828 -0.00014003734 -9.995799f-5 -6.8555826f-5 0.0002944183 2.5830002f-5 -8.6902626f-5 3.1703472f-5 -8.613265f-5 -4.4418608f-5 2.5832413f-5 0.0001425472 1.7966672f-6 5.2554013f-5 0.0001459187 -9.254999f-5 -5.3445347f-5 0.000121297744 0.00012355874 0.00011897578 -4.1606792f-5 0.00010781215 0.00018431008 0.00012606039; 0.0001353546 -9.296944f-5 9.9404015f-6 0.00020451576 0.00017484857 -4.285188f-5 -0.000114315946 -4.7599253f-5 -6.524579f-5 7.224879f-5 2.4338557f-5 0.000116835785 4.305328f-5 2.788327f-5 4.594144f-5 -0.00010692767 -8.730047f-5 -2.166005f-5 7.2841416f-5 0.00013915429 -5.0181334f-5 5.5363876f-6 -5.4738543f-5 -4.1780637f-5 1.2366734f-5 3.6159658f-5 2.1845495f-5 -6.931592f-5 8.884552f-5 6.0043967f-5 4.6574743f-5 -2.192738f-5; 0.00012084936 1.7285647f-5 -2.257981f-5 -0.000113502385 -0.00016654277 3.567039f-5 0.000108881526 0.00014828937 -0.00015351866 -7.459572f-5 -5.0482628f-5 -0.00019637724 0.0002527021 2.0702579f-5 -1.7354685f-5 6.0868197f-6 0.00017867453 -2.6931948f-5 0.00010879932 -0.0001589152 7.255969f-5 -2.91685f-5 -2.6709371f-5 -0.0001344635 -4.027194f-5 0.00011876549 0.00013488093 0.00013301367 -0.000113330534 1.9927278f-5 -3.8568684f-5 7.483079f-6; 1.8358276f-5 5.6514014f-6 2.713095f-6 -2.7231723f-5 7.806908f-5 -0.00014067705 -9.196286f-5 3.1128355f-5 0.000107997344 -5.403785f-6 6.437139f-5 -0.00014994139 -1.833281f-5 0.00016433078 0.00011488936 5.586978f-5 -0.00025687428 -0.000149289 -3.3665925f-5 -3.6165926f-5 -9.000683f-6 9.575764f-5 0.00016443792 7.3836436f-6 5.7788537f-5 3.925063f-5 0.00023229179 9.584901f-6 0.0001006891 6.2911175f-5 -8.0032776f-5 0.00019793026; -0.00013544293 -2.4913448f-5 0.00012132226 -9.498176f-5 4.6829493f-5 5.872477f-5 -7.1489354f-5 -6.7895875f-5 -7.024067f-5 9.6594056f-5 4.154121f-5 7.6403594f-5 -6.6780725f-5 -0.000106721374 -1.6454027f-5 3.2579694f-5 0.00017469058 -7.155961f-5 0.00011838917 0.00020437421 -5.3676515f-5 0.00013281967 0.00018389714 -1.6420998f-5 
-3.4860896f-5 -2.7415583f-5 3.3138294f-5 2.5885292f-5 0.00018559757 -0.00014313504 2.4932313f-5 -0.0002614195; 7.682f-5 -0.00022720349 -0.00018443566 2.2957202f-5 -2.2538018f-7 0.00016329468 0.00016886157 8.873607f-5 -1.9370782f-5 2.7808346f-5 8.543597f-5 0.00011046133 -5.9354727f-5 -5.102017f-5 -8.298613f-6 8.125984f-5 -2.040442f-5 -9.018707f-5 0.00010328428 5.1415705f-6 5.402005f-5 0.00014285462 3.3918102f-5 -0.00010447185 -4.964776f-6 -9.409044f-5 6.369577f-5 -9.165739f-5 0.00018064384 -0.00011941614 -2.1237876f-5 1.7896205f-5; 8.110062f-5 0.00015785586 -0.000101779864 5.765088f-5 -3.291863f-5 2.4336248f-5 -6.0482987f-5 -1.1735118f-6 -0.00012649497 -8.098157f-5 -6.462055f-5 -8.7426844f-5 -0.00014003356 -1.6164565f-6 1.558059f-5 -3.657587f-5 9.313846f-5 4.5238645f-5 -3.3639044f-5 -0.00016472177 -0.000105895415 4.2335465f-5 -7.116041f-6 7.707023f-5 -6.181619f-5 6.955677f-5 8.227454f-5 -0.00014962995 7.841885f-6 -1.6423162f-5 -0.00021802692 9.3289455f-5; 7.794439f-6 3.7744878f-5 4.3177704f-5 4.7335525f-5 0.00026418985 -1.0784334f-5 7.7132034f-5 -4.5300636f-5 7.824199f-5 0.00023536263 -9.0300426f-5 0.00013119685 -3.4900684f-5 0.00019343746 -5.0441784f-5 8.112255f-5 0.00010123477 2.0318299f-5 -0.00015517992 -0.0001732053 4.4359645f-6 0.00021859378 1.2809434f-6 -0.000103867904 3.024207f-5 -0.000249423 0.00012114063 0.000113745395 -7.983323f-6 -0.0001262892 -7.663049f-5 -9.3463226f-5; -6.7503865f-6 6.3668536f-5 -0.00011261818 -6.8528374f-5 0.0001131489 0.000110453715 0.00016307332 -0.00017087662 -4.1190448f-5 -2.3814651f-5 -7.529973f-5 -0.00010907419 -0.0001707087 5.0268587f-5 -2.713221f-5 1.4539705f-5 0.000110737 3.1198786f-6 -6.245283f-5 1.9835235f-5 4.3252305f-5 8.921341f-7 0.0001500738 0.00022314572 6.499398f-5 4.4207787f-5 0.00011729832 -2.1809676f-5 -5.82046f-5 2.6498441f-5 -6.3228385f-5 5.32684f-5; 9.010043f-5 9.8365126f-5 4.800319f-5 -7.424996f-5 -6.683397f-5 6.224268f-5 8.455168f-5 8.0808604f-5 1.7500886f-5 0.00017354896 -5.7063677f-5 1.325431f-5 1.8852767f-5 
7.1699455f-5 6.8227914f-6 0.00015031204 -0.0002455525 -0.00014917058 1.7837681f-5 0.0002150394 1.5505297f-5 -1.5415226f-5 -0.00026166573 3.7217338f-5 -2.3737037f-5 2.0153846f-5 9.584003f-5 -8.4383515f-5 4.8606536f-5 7.243887f-5 -3.4307886f-5 4.65041f-5; 0.00015264371 9.4761075f-5 -7.2132425f-5 0.00011528105 0.000121227065 0.00012555433 1.787181f-5 6.77763f-5 -4.2653453f-5 -8.333443f-5 -3.8327116f-6 0.00013706902 -0.00012886072 -7.2881376f-5 -5.6214525f-5 -3.5672733f-6 3.893894f-5 -2.1584985f-5 1.2203243f-6 0.00024007774 -3.5780076f-6 -0.00013980329 6.044306f-5 -3.2919117f-5 6.4494833f-7 0.00015589148 -9.556244f-5 -2.431852f-5 -6.233793f-5 -0.000104452876 9.635999f-5 -5.99543f-5; 0.0001104488 0.00012269935 6.781889f-5 0.00016702045 -1.5506734f-5 3.989357f-5 1.5545189f-5 3.087678f-5 7.637603f-5 -5.58105f-5 0.000115002425 0.00025789958 -0.00010239304 -5.313668f-6 -0.0001531653 -0.00020709552 0.00013540988 -0.0001952725 -2.9105895f-5 0.00010601422 0.00011143744 9.824675f-6 1.09047f-5 -7.7459044f-5 -4.0916275f-6 0.00010907547 4.220771f-5 -9.5195435f-5 0.00022232848 -4.7265097f-5 3.0194107f-5 -2.7736498f-5; -4.7049845f-5 9.086953f-5 8.62822f-5 -2.6047472f-5 7.7845005f-5 4.397916f-5 -8.108555f-5 2.8170154f-5 -0.00012506326 2.2860559f-5 -0.00010568719 7.8074176f-5 -8.600336f-5 -6.761146f-5 0.00010643894 -2.898733f-5 6.201488f-5 5.1544554f-5 5.5696863f-5 -0.00013348617 2.3037757f-5 -8.377122f-5 -3.0343057f-5 -9.990133f-5 7.8406694f-5 0.00011789547 -0.00015724714 5.6253673f-5 0.00020821806 0.00013089468 0.00016051803 0.00014289122; -4.9461803f-5 -1.8561709f-6 2.218925f-5 2.3032972f-5 -0.00023348974 1.4643718f-5 3.6302226f-5 0.00010657718 -1.685365f-5 0.00015320272 9.220302f-5 -5.0303817f-5 -7.104914f-5 0.00014492426 8.498842f-5 5.6634748f-5 -1.5997208f-6 0.00016457123 -5.472479f-6 -2.4434556f-5 3.9539114f-5 3.9642597f-5 -1.663943f-7 9.0411304f-5 4.8420858f-5 -6.166854f-5 3.8983348f-5 -1.3344237f-5 1.3300328f-5 6.466593f-5 -3.0878724f-5 -0.00015858795; -0.00013847774 
-3.4615816f-5 -3.328201f-5 -8.993783f-6 -9.2572074f-5 3.7261165f-5 9.0857044f-5 -4.7666796f-5 -2.0145673f-5 -0.0002210649 -3.2024338f-5 -0.00022480635 0.0002288795 -4.8793205f-5 0.00015902355 -9.381473f-5 0.00013191598 5.129559f-5 -2.6479904f-5 0.00012022947 9.56596f-5 9.6668264f-5 5.802184f-5 0.00016662564 7.124059f-5 -0.00016689971 -9.391732f-5 -0.00010263432 -0.0001365649 -4.8080423f-5 -0.00010307863 -0.00012316134; 0.00020088724 5.808666f-7 5.7375284f-5 -4.2141994f-5 4.359235f-5 -0.0002235698 -0.00024546034 8.4438456f-5 0.00012988095 -3.3494307f-5 4.9682617f-6 3.3460994f-5 0.00010702192 2.267903f-5 2.2882075f-5 0.0001513779 9.1450944f-5 -0.00023648462 -2.8536468f-5 5.2144216f-5 -0.00019923995 0.0001493044 5.4978987f-5 3.4781567f-6 -0.00020486195 -0.00025362746 -0.00011821811 -0.00019457978 1.3299679f-5 2.7499997f-5 -3.25669f-5 -5.932132f-5; -6.8999696f-5 -6.0698203f-5 4.4967863f-5 -4.3293472f-5 1.8666402f-5 -7.748784f-5 -0.000115003815 2.5786717f-6 5.108097f-5 5.37271f-5 -0.0002972422 -4.0252675f-5 0.0001499504 0.00018190655 -2.9402709f-5 -0.00010817025 0.0001130789 3.465551f-5 5.7016598f-5 -4.867952f-5 -2.884888f-5 0.00010242025 -6.382486f-5 -0.00014510311 6.2672494f-5 -0.000154137 0.00015790152 -0.00018315826 3.690912f-5 -0.00018184977 -5.4926113f-5 0.00017903902; 0.000116034316 -4.4988625f-5 -4.59555f-5 -4.6120218f-5 0.00017863355 9.2869f-6 2.984834f-5 -0.00010062391 7.3249816f-5 2.4249583f-5 0.00016910807 -7.084725f-5 4.6613015f-5 -2.2334123f-6 -2.2657416f-5 8.041271f-5 -0.00010055951 -8.505651f-5 -2.740058f-6 -8.6944594f-5 5.1251238f-5 -0.00015105621 -0.00015197477 3.524098f-5 0.00010522088 -1.7702276f-5 9.7689546f-5 4.484663f-5 0.0002555742 -2.5455427f-5 0.00014188167 -2.2847875f-5; 9.625256f-5 2.5640957f-5 7.3999894f-5 -0.000119444274 8.5259046f-5 -7.4252974f-5 3.139373f-5 -1.2134027f-5 0.00015838892 0.00012898835 8.380874f-5 1.160342f-5 -5.5289656f-6 6.567574f-5 3.9646634f-6 2.1519369f-5 -0.00012693758 -9.2087445f-5 -0.00013727683 5.8879017f-5 
7.652618f-6 -0.00027487444 0.000113362 8.8388355f-5 0.000101502876 -0.00014425952 5.402064f-5 1.2522708f-5 -0.00015310063 5.1685634f-5 -1.8483766f-5 -2.2978125f-5; -0.00018165885 -4.426379f-5 -0.0001353045 -5.9922542f-5 -5.9081663f-5 -2.435207f-5 -0.000109316614 2.1483207f-5 -7.478356f-5 -6.0424787f-5 8.896262f-5 -5.7591187f-5 2.6395348f-5 4.8323345f-5 0.00011701666 5.383323f-5 3.6250938f-5 -2.9462386f-5 5.218619f-5 -0.00015514308 -2.327491f-5 -2.3575294f-5 4.1773023f-5 0.00010307957 -0.0001361062 0.00013519263 2.2158747f-5 5.0823684f-5 0.00014279384 -0.00015585977 8.511958f-5 8.525666f-5], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[4.542509f-6 -0.00013102415 7.2268245f-5 -0.00015671842 -1.46926795f-5 0.0003152379 8.031065f-5 -0.00021266112 0.00010255894 0.00011768247 0.00010438796 -4.8229373f-5 -0.00020793946 -0.00021626358 6.078258f-5 -0.0001287228 4.5852463f-5 -0.00016342802 -3.3170481f-6 4.4717486f-5 -0.00021621436 -0.00013303205 -7.9071484f-5 -1.3936201f-5 7.341936f-6 -0.00010476734 0.00025791454 -9.2659386f-5 -6.448117f-5 -4.9705126f-5 0.0001281067 -9.697857f-5; 6.882313f-5 3.267184f-5 0.00020837394 3.3898337f-5 4.6119967f-5 -8.703703f-5 0.00013592109 -7.099213f-5 -7.578569f-5 3.3541216f-6 -0.00020368487 0.000112562535 4.1836352f-5 0.00014277251 9.987409f-5 6.9864077f-6 2.29218f-5 7.288357f-5 -3.6168527f-5 -8.900577f-5 -2.6943024f-5 -0.00019462623 -0.00012381407 -0.00014623452 -0.0002069357 -1.8087667f-5 9.846663f-5 2.7298278f-5 4.7592526f-5 4.6533085f-5 -3.0006979f-5 8.1224694f-5], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32, however, in this case we need Float64

julia
# Convert the parameters returned by `Lux.setup` into a Float64 ComponentArray; this is
# the initial point handed to the ODE problem and to the optimizer.
const params = ComponentArray{Float64}(ps)

# Bundle the network with its state so it can be called as `nn_model(x, parameters)`.
const nn_model = StatefulLuxLayer(nn, st)
Lux.StatefulLuxLayer{true, Lux.Chain{@NamedTuple{layer_1::Lux.WrappedFunction{Base.Fix1{typeof(broadcast), typeof(cos)}}, layer_2::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_4::Lux.Dense{true, typeof(identity), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}, layer_4::@NamedTuple{}}}(Chain(), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()), nothing)

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics. It uses

u[1] = χ, u[2] = ϕ

where p, M, and e are constants.

julia
"""
    ODE_model(u, nn_params, t)

Out-of-place ODE right-hand side for the state `u = [χ, ϕ]`, with a neural-network
correction.

The constants `p`, `M` and `e` are read from the global `ode_model_params`. The global
`nn_model` is evaluated on `[χ]` with `nn_params`; one plus each of its two outputs
multiplies the shared rate `(1 + e cos(χ))^2 / (M p^(3/2))`. `t` is not used. Returns
`[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` already carries the network state (an empty NamedTuple in this example),
    # so only the parameters are passed. In general the state should be threaded through.
    correction = 1 .+ nn_model([first(u)], nn_params)

    # Rate shared by both derivatives before the learned correction is applied.
    base_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [base_rate * correction[1], base_rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural-network model with the untrained parameters, using the same
# solver settings and save points as the true model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the tuple returned by `compute_waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained-network waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # NOTE(review): `markershape` and `markeralpha` look like Plots.jl-style keywords —
    # confirm the installed CairoMakie version accepts them.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform; markershape=:circle, markersize=12,
        markeralpha=0.25, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn; markershape=:circle,
        markersize=12, markeralpha=0.25, alpha=0.5, strokewidth=2)

    # One legend entry per waveform, each combining its line and scatter handles.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
"""
    loss(θ)

Sum-of-squares mismatch between the reference `waveform` and the waveform predicted by
the neural ODE with parameters `θ`.

The ODE is re-solved from `u0` with fixed-step RK4 at the save points `tsteps`. Returns
the tuple `(loss_value, pred_waveform)`.
"""
function loss(θ)
    trajectory = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(
        compute_waveform(dt_data, trajectory, mass_ratio, ode_model_params))
    squared_error = sum(abs2, waveform .- pred_waveform)
    return squared_error, pred_waveform
end
loss (generic function with 1 method)

Warmup the loss function

julia
loss(params)
(0.1710730202515836, [-0.02424933242660149, -0.02346541744255887, -0.0226815024585165, -0.02135697382994346, -0.019464174648250746, -0.016962962561949065, -0.013800641165922008, -0.009909008035922758, -0.005206626567589536, 0.0004015190696544761, 0.007014837272417639, 0.014716258062701342, 0.023524834987610013, 0.033267965708927, 0.04330464722252075, 0.051871330566536276, 0.05469810095382162, 0.04274611874887868, 0.0024910822901808532, -0.06567377777625286, -0.11027845468148388, -0.07686666067311705, -0.0072721787217171745, 0.03872266290297432, 0.05434645480768314, 0.05304348030037345, 0.04490306189167211, 0.034843357292549615, 0.024920348633773237, 0.015898061854952267, 0.007991482142431335, 0.0011986405068756289, -0.004562215136522249, -0.00939302361565574, -0.013392680779022424, -0.016645878856822095, -0.019224587676866394, -0.021185055403262376, -0.022570019944592134, -0.023409634513427105, -0.02372247121791157, -0.023515897806377636, -0.022787378927456334, -0.021522167825007852, -0.019695452147766815, -0.01727004037480029, -0.014195646189975491, -0.01040836626978794, -0.005829950178071315, -0.0003693675598799802, 0.006069757261227605, 0.013573443188956974, 0.022169782695716674, 0.03172716985517127, 0.041706313041201946, 0.05059765823908509, 0.05470010740717633, 0.04597772480070905, 0.011283302352042824, -0.05394899480371488, -0.1077856868705139, -0.08708703539491623, -0.01792351007958953, 0.033827244535284115, 0.05358331833806951, 0.05408591564244687, 0.04649233261899261, 0.03645057907263675, 0.026357245167394096, 0.017119582656726123, 0.009000824617856817, 0.0020244751960670957, -0.0038974480576836716, -0.00885900698228246, -0.012971820770344736, -0.016317049773601388, -0.018977263090709927, -0.021005621175010904, -0.02245308895758862, -0.02334842303058393, -0.02371510994341558, -0.02356180935836308, -0.022887156728454875, -0.021680634735397962, -0.019919552495557224, -0.017566233842589065, -0.01457925422684336, -0.010890838595468745, -0.006434298121069024, 
-0.001114200075842373, 0.005155139394341568, 0.012467947547515375, 0.020854456153504113, 0.030223026353079635, 0.04011524874770349, 0.04924593625724799, 0.05440969158532076, 0.048502438837719306, 0.019064708612103798, -0.042094971938907016, -0.10292764467040588, -0.0959082956095296, -0.029299392019044735, 0.02798290319165815, 0.052337495815545136, 0.05496633380650938, 0.04806031897792332, 0.03808613535063686, 0.027838065398344505, 0.018378726204999096, 0.010048695378434091, 0.0028750280776517556, -0.003206064631372672, -0.008310256322101396, -0.012533947698765497, -0.015979794296873653, -0.01871883338385365, -0.02082079948192038, -0.022328893507619525, -0.023282666157856888, -0.02370263344347037, -0.023602111651720357, -0.022982513970329375, -0.02183264715993439, -0.020135245305830305, -0.017853789038941373, -0.01494992578950297, -0.011359091890569, -0.007018559772880797, -0.0018357660775235834, 0.004271148815901312, 0.0113970863201934, 0.01957995083436698, 0.028754482909882105, 0.03853948727355842, 0.04783614140599905, 0.05387322859806078, 0.05040669636556338, 0.02583452146690842, -0.03045046838567243, -0.09596592856785918, -0.1029185274343826, -0.04117788755650165, 0.021127316327211782, 0.05053478968551449, 0.05564527586933594, 0.04959321642964968, 0.03974858051953059, 0.0293563340863845, 0.019684773731652658, 0.011129033483511476, 0.003759288977838738, -0.0024969547287470024, -0.007736997972381322, -0.01208447158140484, -0.015628438188120656, -0.018454926161330834, -0.02062534015538384, -0.02220002519928798, -0.023210989274702637, -0.023685360381779857, -0.023637822904417845, -0.023071737023486204, -0.021979455574656748, -0.020341444621746294, -0.0181344617155538, -0.015308203303235587, -0.01181185131606982, -0.007585312039466312, -0.0025356269513523723, 0.0034169155263233256, 0.010361854924290152, 0.01834242189449798, 0.027323424013246653, 0.03698354513353876, 0.04638440647529007, 0.05313684727076321, 0.05176484200771161, 0.03163678375669115, 
-0.019311536401308078, -0.08723440462658277, -0.10778234680084522, -0.05326729059924408, 0.013230855326842558, 0.048084466828950054, 0.05608196806262456, 0.05107378723427992, 0.04142892674985469, 0.03092235084585175, 0.02102856772185991, 0.012248605837710093, 0.004670785376332488, -0.0017547208282763047, -0.007149799015353018, -0.011618628680930047, -0.015267308997741853, -0.01817607509795862, -0.02042587180299276, -0.022064213968110962, -0.023134164315936793, -0.02366259410138147, -0.023668638808577835, -0.023155797225842344, -0.02211834095297614, -0.02054514570284108, -0.0184026962324165, -0.015656363942809545, -0.012251675247677715, -0.008133434397403442, -0.0032121922541800242, 0.0025892241607566465, 0.009358241137758285, 0.01714444316333266, 0.02593019821579364, 0.03545053937663307, 0.04490575995082134, 0.0522347477189875, 0.05265249352973666, 0.03652439889582552, -0.008880083447525621, -0.07714877461316376, -0.11026253622121546, -0.0651955283106553, 0.004279203789700722, 0.044907618326687514, 0.056221849700215314, 0.05248169931337637, 0.04312568512015838, 0.03252514254393217, 0.022418432138576166, 0.013406953386316946, 0.005617021566325474, -0.0009910111277474177, -0.00653999007212984, -0.01113652535359505, -0.014892087651353293, -0.017890488270987562, -0.02021686245401746, -0.021921919343105445, -0.023051579244228296, -0.023635151328203124, -0.02369433386669447, -0.023234656142699046, -0.022252670367504997, -0.020739289502431122, -0.01866378093655497, -0.01599383612349774, -0.012678156001141918, -0.008664737705381335, -0.004651319409620824])

Now let us define a callback function to store the loss over time

julia
# Loss value recorded at every optimizer iteration, in call order.
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Record the current loss `l` in the global `losses` history and print a progress line
showing the iteration count (the length of `losses`) and the loss value.

`θ` and `pred_waveform` are accepted to match the signature the optimizer calls with,
but they are not used. Always returns `false`, so the callback never asks for an early stop.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf("Training %10s Iteration: %5d %10s Loss: %.10f\n", "", iteration, "", l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It gives good results here because the Newtonian model already provides a very good initial guess.

julia
# Use Zygote as the automatic-differentiation backend for the objective.
adtype = Optimization.AutoZygote()

# The objective receives `(x, p)`; the second argument is ignored and only the
# parameter vector `x` is forwarded to `loss`.
optf = Optimization.OptimizationFunction(adtype) do x, p
    return loss(x)
end
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, limited to 1000 iterations; `callback`
# is invoked by the solver during the run.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00022478317259776824; -0.00014757248572988297; -0.00015495336265291423; 9.2691363533947e-5; -6.675012991755764e-5; 9.5317656814537e-5; 0.00012224247620891524; 0.0001455017772968316; 0.00010277838009639748; -7.804002962068974e-5; 8.976822573444471e-5; 4.12237568525263e-5; 3.4318280086154064e-5; 1.0803945770012727e-5; -9.722193499334526e-5; 1.8967344658442153e-5; -4.486825491769027e-5; 0.0001373958075419554; 3.924082648154162e-6; 4.6769564505614595e-5; 9.524104098075115e-5; -1.8312435713602697e-5; 9.288281034964838e-6; 9.24652995309927e-5; -0.0002862989495040917; -0.00023564558068731479; 0.00018884899327518083; -5.3668577493156445e-6; 0.0001441069762220741; 8.874541526886306e-5; -7.57755406083609e-5; -2.5291350539179012e-5;;], bias = [-3.822056322958917e-16; -1.9384610130869822e-16; -1.3217494707598195e-16; 1.2777509618778491e-17; -8.919634703214788e-17; 5.757513521288436e-17; 1.7278505284659062e-16; 5.795274456647421e-17; 3.086131014594439e-17; -7.027982582426247e-17; 1.6138865748634454e-16; 3.816736566167081e-17; 1.2701673613706582e-19; 1.0055497248964762e-17; 6.665060245157772e-17; 7.98236914353724e-18; 1.8769781478746176e-17; -5.761136158791074e-17; 4.458047996576838e-19; 5.995522898756678e-17; 1.793848289195624e-16; 6.690501936679815e-18; 4.968329600930291e-18; -1.1996644960202014e-16; -2.0273831106466704e-16; -2.80505292883811e-16; 1.5202947086244343e-16; -7.903509626795857e-18; 1.9550261105792043e-16; 7.796255865171515e-17; -9.751716399039147e-17; -1.4454942776312106e-17;;]), layer_3 = (weight = [0.00010002320874345015 -0.00012869199321276616 -4.391862506882428e-5 -2.6610880867433175e-5 -0.00011390884077546078 0.00013940304876168308 0.00012123695438357978 9.829745123073359e-6 3.626779385821924e-5 -0.0003069656003745916 6.470254544264701e-5 -4.41056790251457e-5 6.623852696881261e-6 0.00025593449487214135 -2.2498906163586777e-5 0.00021290599494104432 4.446071726212036e-6 
-0.00010960886624080201 -0.00013266422951265072 4.98340314706149e-5 6.958857653525709e-5 -0.0001197090959422926 -0.0001276979664354289 -2.9618610288222193e-7 0.00017064363879479533 0.00010679625216734058 -6.101165758000973e-5 -6.729458158506846e-5 -4.555265962981717e-5 -4.612503011149289e-5 9.517656968536878e-5 6.093324287657781e-5; 3.0308701067611806e-5 -5.7507867271292674e-5 1.3832689770984597e-5 -0.00011073901692124338 0.00025416316497886153 -1.0775451986219312e-5 0.00026259102309869576 0.00014973889778649628 -0.00010505056411910986 -6.450940850608393e-5 -0.00014506327420657776 8.426665158058949e-5 -0.00013259574621169416 8.929052569019983e-5 -7.306769544210509e-5 9.644234748162916e-6 -3.984037971658188e-5 -0.00017194189854616187 -1.0334133324607744e-5 9.099305540628104e-6 5.732288559248434e-5 7.715294054524424e-5 -0.0002186626850319896 0.0001411621299181294 1.3227892165230799e-5 1.623096143770095e-9 -0.00012164097711627597 -5.375195772461357e-6 -2.305431985713364e-5 1.1527344437681132e-5 4.780872700085507e-5 -0.00026921622771023377; 0.0001410838637303852 -0.00010824501381364568 -0.0002740753219194422 2.3870302050566034e-5 -0.00013748106297214457 0.00013524124066236493 3.371778499342575e-5 0.00010791815717595712 2.7724456834332774e-5 -9.46929024729592e-5 9.05346356950055e-5 1.1350192528991598e-5 9.255676255918164e-5 2.3557250336236873e-5 7.920843623935313e-5 6.736402853050324e-5 2.1860527943429185e-5 0.00014652051746736358 -4.144589123541434e-5 -0.00010761398456168304 9.659721096170203e-5 -4.347965963177328e-5 -7.34781527506834e-5 -2.713294755892762e-5 3.9148206872708835e-5 2.184430207419894e-6 6.398153769907644e-5 -3.328244141528328e-5 -5.8131441016373346e-5 -0.0001157740874413115 -7.932507590965307e-5 -2.120104303401373e-5; -0.00015654863803404428 5.173964033844567e-5 4.619041345967553e-6 -0.00019821538428998987 2.908447112592899e-6 0.00010878431946281843 -0.00011646973856016441 -5.2301528111737395e-6 0.00011958818519730634 9.610868452825124e-5 
0.00014536873656567675 1.3521869678399003e-6 0.00010315220440049152 -0.00024866000588275066 -6.195834415023643e-5 2.5924331720550078e-5 -0.00013902078338190484 0.0001330846354622453 -9.755761114405052e-5 1.577419986376043e-5 1.0868976613954054e-5 -5.3653333298764675e-5 -1.645762661695064e-5 -0.00020141785337812702 -5.155133824784991e-5 0.00011351903246493079 -4.446947407898677e-5 0.00014995254620686583 0.00019539382644240423 -3.426515088988911e-5 4.664452732783279e-5 1.9430801484365508e-5; -5.766804031588045e-5 6.10564177042225e-6 -0.0002665824766867303 8.630108197540253e-5 -0.00013964302059175503 -8.446780046501596e-6 1.5841121350248887e-5 -3.5325458955763976e-5 -5.1556388715101806e-5 9.587400120072115e-5 -3.8921699232010294e-5 -2.6363052600936493e-5 9.610859990207216e-5 -8.290252370353392e-6 -5.5249442883354785e-5 -9.192448203830362e-5 7.849905079783661e-5 -3.927141813472911e-5 4.230971819342705e-5 -0.0001074781511109901 -0.0002298148073154105 5.967189929069632e-5 -9.501749889605079e-5 -0.00017922642735210972 7.753286183438501e-5 0.0001472311515696109 6.512659740258084e-5 -6.849021136441625e-6 0.00011840128274490183 -7.874520531767818e-5 0.00014587236648513032 -6.405721513001546e-5; 0.00023816081743636522 4.875494783113776e-5 0.00011359736895605794 0.00019396179054816538 -1.0346737279210321e-5 -0.00015866967801051902 4.323446778649194e-5 0.00018281730813391625 -7.063723382820267e-5 -0.00011452193934750243 -2.03746876177671e-5 -1.4111350974713072e-6 8.063080536155028e-5 -1.1615603152959054e-6 -6.019648235814639e-7 -6.649126221239498e-5 7.464077139226646e-5 6.726266660912281e-5 -5.767444958840398e-5 -8.465559580875648e-5 0.00011733088836259448 -1.0292138493269898e-5 -2.7977088830366258e-6 -1.1274586493563295e-5 0.0001269558670645807 -0.000147418311778969 -1.1212896377460366e-5 5.7511882961961486e-5 1.798675662207571e-5 0.00015558693489943978 9.294510211442324e-6 -4.942180744185999e-5; 5.555193748250826e-5 -8.0175241957915e-5 9.89037045724372e-5 
-8.801210314456957e-5 3.052281063676424e-5 2.022014924252729e-5 3.487966257600221e-6 0.0001446905200350725 -1.097594921927879e-5 -2.484354335363024e-5 0.00018342629228529895 -0.0002889338382305789 0.00013465262346246977 -2.5302054193589807e-5 -5.0205332438486364e-5 -8.997280740432522e-6 3.6629165340426454e-5 0.00012804934268162446 9.785453331235618e-5 0.00016550998880805162 -3.950834324539315e-5 -0.00011322395838985414 -9.81131550292689e-5 -9.468701572811812e-5 -3.065791569725946e-5 5.617395182298481e-5 2.4351721738194997e-5 3.1763295926366074e-5 0.000109850433734681 6.442766353532122e-5 -3.726696631244834e-5 5.548164797373133e-6; 4.6164985299707126e-5 5.161528145722188e-5 2.0061294101784027e-5 -7.742745389518755e-5 -1.3717133576253544e-5 -9.954350559096295e-5 -1.68396217003202e-5 -3.915910717162985e-6 0.00011855711466810177 4.38667996891314e-6 0.00015394410559852205 8.849519338869289e-5 3.671555726500975e-5 -2.0301473220429938e-5 -0.00012429135119132215 -7.177163631100379e-6 5.133614298133899e-5 -0.00010800173355787823 -5.959637734744843e-5 5.709347363622643e-5 -5.007708192974771e-5 0.00017321541066226393 -6.67167058662433e-5 -0.00013019299682750783 5.617914406059281e-5 5.301619797422454e-5 -0.0001318227822292634 0.00013977637548410188 -0.00016327289073783796 0.00012336481280103863 1.656314646711706e-5 0.0001686095985252487; -8.352232376581105e-5 -9.004865401891935e-5 -2.4646466189897125e-5 -0.00017855695310516176 -0.00011454158644426495 0.00013401061900423944 -4.1139405001904e-5 0.0001396353125897987 8.549450452903383e-5 -7.662235853729657e-5 -1.9241042681987255e-5 0.00017821326931936798 0.00013309843219814956 -8.922694374576077e-5 9.732369093455476e-5 -0.00014777423908470064 7.2496621584113475e-6 -0.00013150310238359432 4.55731440894133e-5 -0.00010368089548081032 0.0002966858390578824 6.1203103001580514e-6 2.881407439474152e-5 3.47828253381583e-6 7.403695948652764e-5 -1.3061822900800672e-5 -0.00011647063925864168 5.705385214298168e-5 5.13468165451015e-5 
-2.8001632162503753e-5 -8.166231251323482e-5 3.160254877058923e-5; 0.000114737114506898 4.581331044192837e-5 0.00010036863344595172 8.079354259991037e-5 7.5342727885241685e-6 -8.797842803854264e-6 2.457393342212031e-5 -2.9946692359033353e-5 8.346052303551575e-5 4.120134648275597e-5 -0.0001581718392682731 -7.805300307352836e-5 0.00011209305333779403 -3.297250847224831e-5 0.00011541954112346203 -8.349741298132336e-5 0.00011933369345371793 -9.162487591510121e-5 2.602285610755969e-6 7.128671089760202e-5 0.0001697031689562362 -1.5355276101943368e-6 2.1732850234748483e-5 -0.00016322624329603332 3.4757394621727574e-5 5.033621304191798e-5 -0.0001138757292776004 -7.851660526483069e-5 7.829022755487339e-5 -0.00021620275261937668 1.2174537494336737e-5 6.658885062081038e-5; 3.804209428334639e-5 -0.0002675591158463686 -2.875879635261761e-5 7.589429243074336e-5 -1.7445468066008407e-5 -4.095401799583211e-5 -0.00015395938588931782 0.00012463918484122692 6.711465631218473e-5 -0.0001223087519889022 -0.0002467721415001201 -1.3581264275285135e-5 -2.342844435640972e-5 -0.00014865842876186154 -1.4799585178993282e-5 -2.656484774941108e-5 -7.83488494476609e-5 3.061804182784551e-5 2.2765126240686144e-5 -0.00010052617198250496 -0.00016817202321414238 6.150794161787724e-5 6.213401594270459e-5 8.14083187009913e-5 9.07565054235712e-5 -9.78866801147006e-5 8.18409444897993e-6 -8.379000700240524e-5 -2.594659326000456e-5 -5.399609925968122e-5 1.9726011495941376e-5 1.975349642582277e-5; 2.9483604942917197e-5 0.00019273442418445312 -0.00014670394333414513 -7.50176000442204e-5 -4.3327756008207744e-5 -3.012353366691073e-5 7.749475158882378e-5 0.00023461805218495888 -3.056906601748008e-5 3.0405896235181553e-5 0.00013767442984638438 8.02944818711118e-5 -3.9739552027151895e-5 1.7612414434132154e-5 -0.00014911375504786568 9.878411625644541e-5 5.763947816698093e-5 8.625455344692524e-5 0.00012202982585544147 -0.00013073100447987942 -9.106905541984947e-6 8.0861861045867e-5 0.0001234057676479444 
-9.225920397190718e-5 2.511398494194713e-5 9.45994510980079e-5 -8.404521535021085e-5 7.140257788285675e-6 -2.8119625800610256e-5 -0.0001271396500091842 0.00014290710938760906 7.270685861252573e-5; 0.00012738851225187154 -3.857401455565662e-5 -1.030211982101535e-5 -0.00010694611549654684 0.00013400122267382986 3.982751738119174e-5 4.018725163956659e-5 -5.6416237870457154e-5 0.0001009218017007417 -0.00014003382378534765 -9.995447320209461e-5 -6.855230597822768e-5 0.0002944218271896544 2.583352258269087e-5 -8.68991060396451e-5 3.170699207301566e-5 -8.61291278251155e-5 -4.441508784470119e-5 2.5835932743650684e-5 0.00014255072025305693 1.800187388717928e-6 5.2557532822977205e-5 0.0001459222243884892 -9.254647150912957e-5 -5.3441826397849497e-5 0.00012130126421289955 0.00012356226076557246 0.0001189793024114648 -4.160327219682934e-5 0.000107815669676777 0.0001843135964609931 0.00012606390961635206; 0.00013535638218351533 -9.296765701539159e-5 9.942187811688636e-6 0.00020451754617578852 0.0001748503528729063 -4.285009335536035e-5 -0.00011431415977814108 -4.759746652372421e-5 -6.524400295848574e-5 7.22505757080363e-5 2.434034295628849e-5 0.0001168375712081788 4.305506454738799e-5 2.788505671969689e-5 4.5943226818502765e-5 -0.00010692588320625888 -8.729868195178026e-5 -2.165826289957493e-5 7.284320245571512e-5 0.00013915607642456842 -5.0179547447902056e-5 5.538173904369494e-6 -5.473675618804915e-5 -4.1778850477845e-5 1.236852049664486e-5 3.6161443953882555e-5 2.1847281554507305e-5 -6.931413726806564e-5 8.884730423641176e-5 6.0045753438766914e-5 4.6576529807222287e-5 -2.192559431523507e-5; 0.00012084980221915986 1.728608618068451e-5 -2.2579371257597542e-5 -0.00011350194585756733 -0.00016654233081280528 3.567082854206333e-5 0.0001088819651653482 0.00014828981034443198 -0.0001535182211642656 -7.459528319086247e-5 -5.048218849478617e-5 -0.00019637679838123968 0.00025270253786802995 2.0703017757481627e-5 -1.7354245979365637e-5 6.087258729534387e-6 0.0001786749660408566 
-2.6931508951382873e-5 0.0001087997613962253 -0.0001589147552709585 7.25601300968289e-5 -2.916806010543585e-5 -2.6708931950994098e-5 -0.0001344630556974128 -4.0271502480762885e-5 0.00011876593157866895 0.0001348813700627929 0.00013301411195643965 -0.00011333009501467965 1.9927717179964776e-5 -3.856824445901732e-5 7.483518178927601e-6; 1.836043044787768e-5 5.653555734567222e-6 2.7152493825584434e-6 -2.7229568542273603e-5 7.807123690185998e-5 -0.00014067489176482358 -9.196070389549127e-5 3.113050913128683e-5 0.0001079994988398949 -5.401630670850378e-6 6.437354092526857e-5 -0.00014993923762772105 -1.833065583154862e-5 0.00016433293127961698 0.00011489151144673623 5.5871933699066164e-5 -0.0002568721240382258 -0.00014928684616421502 -3.3663770620066084e-5 -3.6163771466407005e-5 -8.998528966507168e-6 9.575979562983138e-5 0.00016444007703144344 7.385797988214675e-6 5.7790691033360045e-5 3.9252784307254955e-5 0.000232293940615296 9.587055290879324e-6 0.00010069125780177011 6.291332899167822e-5 -8.0030621785728e-5 0.0001979324158750867; -0.00013544185798345613 -2.491238023174879e-5 0.00012132332472866838 -9.498069424842443e-5 4.683056091346031e-5 5.872583847996801e-5 -7.148828603056439e-5 -6.789480698024319e-5 -7.023960073673194e-5 9.659512421387373e-5 4.154227850414871e-5 7.640466197665129e-5 -6.677965733649842e-5 -0.00010672030633975416 -1.6452959007783685e-5 3.2580761635129805e-5 0.00017469164436954344 -7.155854267728833e-5 0.0001183902406952471 0.0002043752813366361 -5.367544667058601e-5 0.00013282073474294836 0.000183898204577075 -1.6419929798193788e-5 -3.485982755629062e-5 -2.7414514752670266e-5 3.313936236701004e-5 2.5886360399111758e-5 0.00018559863952724817 -0.00014313397126932561 2.4933380936655687e-5 -0.0002614184429515959; 7.682114704515706e-5 -0.00022720234079747384 -0.00018443450932386342 2.2958351505856197e-5 -2.242301795747573e-7 0.00016329583053016663 0.00016886272390843703 8.873722293651764e-5 -1.936963195907382e-5 2.7809496200696753e-5 
8.543711777316072e-5 0.00011046248294884244 -5.9353576548109573e-5 -5.1019021670744166e-5 -8.297462559078509e-6 8.126098731305407e-5 -2.040327086864611e-5 -9.018592351344228e-5 0.00010328542742650925 5.142720451895181e-6 5.4021198989409934e-5 0.000142855773842949 3.3919252414544715e-5 -0.00010447070195800047 -4.963626056244732e-6 -9.408929287452518e-5 6.36969202198509e-5 -9.165624175215172e-5 0.00018064498766330188 -0.00011941498775262715 -2.1236726356384028e-5 1.789735546050073e-5; 8.109829262169502e-5 0.0001578535336469125 -0.00010178218896697928 5.764855425722941e-5 -3.292095449951746e-5 2.4333923402354928e-5 -6.048531237838265e-5 -1.1758367188032645e-6 -0.0001264972927580848 -8.098389323067995e-5 -6.462287662516344e-5 -8.74291693159706e-5 -0.0001400358853695314 -1.6187814432739175e-6 1.557826520211172e-5 -3.657819640439583e-5 9.313613166368114e-5 4.523631990224015e-5 -3.364136889077396e-5 -0.00016472409722128394 -0.00010589773980400444 4.2333140054590785e-5 -7.118366074526161e-6 7.70679068685234e-5 -6.18185115578304e-5 6.955444744593335e-5 8.227221745102896e-5 -0.00014963227683411587 7.839560321653603e-6 -1.642548729223675e-5 -0.00021802924699814227 9.328712961204725e-5; 7.796594413250542e-6 3.774703347011789e-5 4.317986005339004e-5 4.7337680566174275e-5 0.00026419200204550884 -1.0782178638279017e-5 7.71341894637314e-5 -4.5298480277374497e-5 7.824414407373575e-5 0.0002353647889453501 -9.029826995678834e-5 0.00013119900742318343 -3.489852841588771e-5 0.000193439615226872 -5.043962827180262e-5 8.112470296915574e-5 0.00010123692683415324 2.032045469099712e-5 -0.00015517776572372786 -0.00017320315035753125 4.438120153694614e-6 0.00021859593308769126 1.2830990630947234e-6 -0.00010386574811995652 3.0244226049397378e-5 -0.0002494208436342719 0.00012114278271331731 0.00011374755032598948 -7.98116687797551e-6 -0.000126287048065519 -7.662833175849592e-5 -9.346107052400774e-5; -6.749136513738323e-6 6.366978634624513e-5 -0.0001126169314141324 -6.852712392563618e-5 
0.00011315014710427746 0.0001104549650569569 0.00016307456683144957 -0.00017087537123578803 -4.118919748569606e-5 -2.381340125273962e-5 -7.529847932736053e-5 -0.00010907294342749154 -0.00017070745668595972 5.0269837107152784e-5 -2.71309596194299e-5 1.45409552949057e-5 0.00011073824719071043 3.1211285892927146e-6 -6.245158290246766e-5 1.9836485467493724e-5 4.3253554782237085e-5 8.933841292853905e-7 0.000150075049919651 0.00022314696836994344 6.499523299288101e-5 4.420903717409135e-5 0.00011729956752397252 -2.1808425639646916e-5 -5.8203351340388875e-5 2.6499691035238737e-5 -6.322713450052344e-5 5.3269648947861866e-5; 9.01021143913382e-5 9.836680916168636e-5 4.8004871551252626e-5 -7.424827848762574e-5 -6.683228779319745e-5 6.224436443563222e-5 8.455335997541718e-5 8.081028723654279e-5 1.750256929927098e-5 0.00017355064706908377 -5.7061993665679274e-5 1.3255993017545746e-5 1.8854449499949018e-5 7.170113754956216e-5 6.824474403860535e-6 0.00015031372682688908 -0.0002455508270390884 -0.00014916889407641184 1.783936428229029e-5 0.0002150410823201207 1.550697967150388e-5 -1.5413542805154214e-5 -0.00026166404635613615 3.721902108782604e-5 -2.3735353842802713e-5 2.01555289645299e-5 9.584170989954924e-5 -8.438183207498693e-5 4.860821942247415e-5 7.244054946614167e-5 -3.430620339667569e-5 4.650578417614237e-5; 0.00015264520703331774 9.476257042244746e-5 -7.213092971376515e-5 0.00011528254622810903 0.00012122856065382832 0.00012555582730549737 1.787330443896891e-5 6.777779713722646e-5 -4.265195815839519e-5 -8.333293584889228e-5 -3.831216365660666e-6 0.0001370705104431489 -0.00012885922109723277 -7.287988041076948e-5 -5.62130298408832e-5 -3.5657780632077335e-6 3.894043407620632e-5 -2.1583489744967128e-5 1.2218195546981376e-6 0.00024007924008168045 -3.5765123744296496e-6 -0.0001398017956876769 6.0444555908756825e-5 -3.291762181340334e-5 6.464435908954553e-7 0.00015589297623355832 -9.556094659509143e-5 -2.4317025183850368e-5 -6.233643575121859e-5 -0.0001044513809961877 
9.636148393604495e-5 -5.9952802932971845e-5; 0.00011045157393598999 0.00012270212559888942 6.782166551493705e-5 0.0001670232285039885 -1.5503956229779543e-5 3.989634737883775e-5 1.554796640340724e-5 3.0879558697789144e-5 7.637881012562612e-5 -5.580772099271997e-5 0.0001150052029050345 0.0002579023547230735 -0.00010239026064588824 -5.310890766093929e-6 -0.00015316251861002242 -0.00020709274590550983 0.00013541265602619174 -0.00019526972223610488 -2.910311755037434e-5 0.00010601699418551058 0.0001114402165047274 9.827452094060682e-6 1.090747705241626e-5 -7.745626694593447e-5 -4.08885002911989e-6 0.000109078244212266 4.221048752587769e-5 -9.519265781925094e-5 0.0002223312602371253 -4.726231977763396e-5 3.0196884698694668e-5 -2.7733720481745865e-5; -4.7047835761058333e-5 9.087153877684743e-5 8.628420757217083e-5 -2.6045462146566196e-5 7.784701439873401e-5 4.398116790894206e-5 -8.108353948132728e-5 2.8172163655692046e-5 -0.0001250612535668102 2.286256814755016e-5 -0.00010568518198938256 7.807618523570974e-5 -8.600135195678733e-5 -6.760944823412787e-5 0.0001064409530218201 -2.898532009652927e-5 6.201688804053947e-5 5.154656311701134e-5 5.5698872091904335e-5 -0.00013348416403545762 2.3039766819284042e-5 -8.376921185203474e-5 -3.0341047459790485e-5 -9.989931866671139e-5 7.840870377458338e-5 0.00011789747943022782 -0.00015724512784888112 5.625568293822375e-5 0.00020822006765142042 0.0001308966907518437 0.0001605200391341467 0.00014289323145845315; -4.9459981453028164e-5 -1.8543489360771785e-6 2.219107162348121e-5 2.303479348742051e-5 -0.00023348792157517852 1.4645539779701183e-5 3.6304048337873224e-5 0.00010657900136978192 -1.6851827892572655e-5 0.00015320454557044228 9.220484506189503e-5 -5.0301995285879224e-5 -7.104731477480293e-5 0.00014492607741234975 8.499024559827117e-5 5.663657000954402e-5 -1.5978988286933201e-6 0.00016457305471961355 -5.470657212896772e-6 -2.443273362328734e-5 3.954093634350777e-5 3.9644418650680693e-5 -1.6457232587226184e-7 9.041312594748683e-5 
4.842267961320397e-5 -6.166672093713084e-5 3.898516959709012e-5 -1.3342415256517286e-5 1.3302149712062376e-5 6.466775383418116e-5 -3.087690194403003e-5 -0.00015858612888874362; -0.0001384795792828061 -3.461765266455821e-5 -3.328384594259763e-5 -8.995619856251511e-6 -9.257391105257183e-5 3.725932799079521e-5 9.085520684402847e-5 -4.7668632424568374e-5 -2.0147510104834436e-5 -0.00022106673316648346 -3.2026174841122654e-5 -0.0002248081906427667 0.00022887765623544868 -4.879504161272664e-5 0.00015902171105062677 -9.381656457754304e-5 0.00013191414142874644 5.129375164779551e-5 -2.6481740525020366e-5 0.00012022763655203982 9.565776087492769e-5 9.666642687897675e-5 5.80200033882586e-5 0.0001666238061889078 7.123875217898485e-5 -0.00016690155064768127 -9.391915557988207e-5 -0.00010263616025532135 -0.000136566730026055 -4.808225970100618e-5 -0.00010308046660702808 -0.00012316317918794668; 0.00020088504527973147 5.786720920244233e-7 5.7373089595195854e-5 -4.214418810866856e-5 4.3590155775142766e-5 -0.00022357199929379904 -0.0002454625326833185 8.443626132370796e-5 0.0001298787566441583 -3.3496501481442546e-5 4.966067200566133e-6 3.345879952697438e-5 0.00010701972781256157 2.267683624389355e-5 2.2879880936071068e-5 0.00015137570438249293 9.144874923671043e-5 -0.0002364868095070524 -2.8538662144152916e-5 5.214202167856895e-5 -0.0001992421429064595 0.0001493022019816029 5.497679297271422e-5 3.475962222491819e-6 -0.00020486414438764962 -0.0002536296495860579 -0.00011822030216411684 -0.0001945819793660729 1.3297484775966513e-5 2.7497802085544505e-5 -3.256909337202374e-5 -5.9323513477162476e-5; -6.900130903235195e-5 -6.069981610640096e-5 4.496624968592051e-5 -4.329508520901113e-5 1.8664788865077888e-5 -7.748945573699547e-5 -0.00011500542824510988 2.5770585918024673e-6 5.107935647824043e-5 5.3725487657300715e-5 -0.000297243806574316 -4.0254287899064775e-5 0.00014994878655410067 0.00018190493791474604 -2.9404321697129444e-5 -0.00010817186340579177 0.0001130772892676175 
3.4653896716140966e-5 5.70149844744579e-5 -4.8681134660786635e-5 -2.8850493993432037e-5 0.0001024186370951927 -6.382647335418859e-5 -0.0001451047261187891 6.267088100428933e-5 -0.00015413861330019536 0.00015789990937915287 -0.0001831598761899526 3.690750817181345e-5 -0.00018185138252452894 -5.4927726171501584e-5 0.00017903741025942046; 0.00011603602964961249 -4.498691137930351e-5 -4.595378792080195e-5 -4.6118504687242e-5 0.00017863526212817126 9.288613765179575e-6 2.9850052524790544e-5 -0.0001006221974960429 7.325152895712952e-5 2.4251295909140936e-5 0.00016910978214649346 -7.084553741084192e-5 4.661472834481176e-5 -2.231698917293504e-6 -2.2655702340912346e-5 8.041442546595182e-5 -0.00010055779799518905 -8.505479869133603e-5 -2.738344764347179e-6 -8.694288058837683e-5 5.125295123501488e-5 -0.0001510545008925505 -0.00015197306143751056 3.52426939941647e-5 0.00010522259603084442 -1.7700562383919053e-5 9.769125958038674e-5 4.4848344114817094e-5 0.0002555759236572411 -2.5453713678201987e-5 0.00014188338158486406 -2.2846161445395082e-5; 9.625290345994122e-5 2.564130096417072e-5 7.4000238384005e-5 -0.00011944392957829107 8.52593898902449e-5 -7.425262941850522e-5 3.139407505060309e-5 -1.2133682902085872e-5 0.00015838926861889472 0.00012898869550682468 8.380908781092452e-5 1.1603764106814314e-5 -5.528621330620633e-6 6.567608448805215e-5 3.965007691850749e-6 2.1519712901328867e-5 -0.00012693723459832463 -9.208710020046965e-5 -0.00013727648678340087 5.8879361460105496e-5 7.652961868562878e-6 -0.0002748741002856947 0.00011336234689352962 8.838869922457095e-5 0.00010150321992241724 -0.00014425917964999322 5.402098264440425e-5 1.2523052247536118e-5 -0.00015310028305847497 5.1685978688884794e-5 -1.848342202591845e-5 -2.2977781044198986e-5; -0.00018165962074053902 -4.426456262227373e-5 -0.00013530526603209024 -5.992331471280754e-5 -5.9082435929316226e-5 -2.4352842321252336e-5 -0.00010931738625150574 2.148243465707396e-5 -7.478433081051843e-5 -6.042555951500003e-5 
8.896184472031318e-5 -5.759195965003025e-5 2.6394575523577605e-5 4.832257281885639e-5 0.00011701588546055762 5.383245906922198e-5 3.625016570220906e-5 -2.946315844529607e-5 5.218541873399322e-5 -0.0001551438484071826 -2.3275682271207575e-5 -2.3576066543332543e-5 4.177225009551428e-5 0.00010307879592654665 -0.00013610697469775194 0.00013519186058911196 2.21579745797453e-5 5.082291199722482e-5 0.00014279306806055575 -0.00015586054478409345 8.511880759978035e-5 8.525588664123027e-5], bias = [3.8659063332054014e-10; -1.0811649132584174e-9; -3.9414660455054684e-11; -2.836085814411736e-10; -1.95439352062851e-9; 2.71761743452508e-9; 1.4526517857746362e-9; 9.95108306819811e-10; 1.1081339926925013e-10; 1.0909732921970337e-9; -3.760752021718679e-9; 2.572472602214027e-9; 3.52016839555818e-9; 1.786342636557036e-9; 4.390625727643187e-10; 2.1543823579688095e-9; 1.0679829061409475e-9; 1.1499971707131035e-9; -2.3249137555329097e-9; 2.155697494532595e-9; 1.2500329102982973e-9; 1.682984029719651e-9; 1.4952615262742952e-9; 2.777459332402429e-9; 2.0095458757666226e-9; 1.8219751321917225e-9; -1.8368437002363364e-9; -2.194491268616344e-9; -1.6130915041095554e-9; 1.7133431727004091e-9; 3.4429959881601205e-10; -7.724353190069279e-10;;]), layer_4 = (weight = [-0.0006651270612979587 -0.0008006936919792633 -0.0005974013288984016 -0.0008263879946612265 -0.0006843621540455762 -0.000354431499178204 -0.0005893588681578035 -0.0008823306683941675 -0.0005671106366976779 -0.0005519870712865922 -0.0005652812716961771 -0.0007178987703720384 -0.0008776086878910871 -0.000885933062276219 -0.0006088869906303248 -0.0007983922416704145 -0.000623817082399607 -0.0008330975572876093 -0.0006729864833351184 -0.0006249519714253154 -0.0008858838952209818 -0.0008027015512595371 -0.000748740999652364 -0.0006836055787761119 -0.0006623275369962039 -0.0007744368284596947 -0.00041175495733159704 -0.0007623288307687365 -0.0007341506784588186 -0.0007193746225672538 -0.0005415628769880931 -0.0007666481268364533; 
0.00030138533859586347 0.0002652340386674532 0.000440936147869538 0.00026646054617491253 0.0002786821421738702 0.00014552512372937584 0.00036848327839815177 0.00016157007054749825 0.00015677651898787486 0.000235916321124889 2.8877216629558145e-5 0.0003451246835329103 0.00027439844174550435 0.00037533468552238594 0.000332436294983669 0.00023954857382322878 0.00025548399995594864 0.0003054457699921054 0.0001963936345605925 0.00014355639845019182 0.0002056191707323231 3.7935949157669475e-5 0.00010874812092629262 8.632761727885637e-5 2.562647496320392e-5 0.00021447451207654757 0.00033102881495165794 0.0002598604429110162 0.00028015471195700275 0.0002790952677666988 0.0002025552299511393 0.00031378689851788133], bias = [-0.0006696695742147018; 0.00023256220966125143;;]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    # Plot the recorded loss history: a line through all values with a marker on each.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    # Makie's scatter attributes are `marker` and `alpha`. `markershape` and
    # `markeralpha` are Plots.jl keyword names that Makie does not recognise.
    scatter!(ax, eachindex(losses), losses; marker=:circle,
        markersize=12, alpha=0.25, strokewidth=2)

    fig
end

Finally let us visualize the results

julia
# Set up the ODE problem with the optimised parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)

# Integrate with non-adaptive RK4 at step `dt`, saving at `tsteps`, then convert
# the solution object to a plain array.
soln_nn = solve(
    prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false) |> Array

# Keep only the first value returned by `compute_waveform`.
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    # Overlay three waveforms on one axis, each drawn as a line plus markers:
    # `waveform`, `waveform_nn`, and `waveform_nn_trained`.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Makie's scatter attributes are `marker` and `alpha`. The Plots.jl-style
    # `markershape` / `markeralpha` keywords are not recognised by Makie, so the
    # valid `alpha=0.5` already present is the one kept for transparency.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    # Each legend entry groups a line with its matching scatter plot.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils

# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# Print CUDA details only when LuxCUDA is defined and CUDA reports itself functional.
if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

# Likewise for AMDGPU, guarded by LuxAMDGPU being defined and functional.
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 6 default, 0 interactive, 3 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 6
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.725 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.