Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

# Disallow scalar indexing of GPU arrays so that accidental element-by-element access
# raises an error instead of silently running slowly.
CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, they aren't especially relevant.

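We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. Theoretical Mechanics of Particles and Continua, §4.3). With the centre of mass at the origin and $M = m_1 + m_2$, this gives $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$.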

julia
"""
    one2two(path, m₁, m₂)

Split the relative separation `path` (r = r₁ - r₂) into the two individual body paths,
assuming the centre of mass sits at the origin.

Returns `(r₁, r₂)` with `r₁ = (m₂ / M) .* path` and `r₂ = (-m₁ / M) .* path`, where
`M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    # Each body is displaced along the separation by the *other* body's mass fraction.
    frac₁ = m₂ / total_mass
    frac₂ = -m₁ / total_mass
    return frac₁ .* path, frac₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$ and $y = r\sin\phi$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an orbit solution in `(χ, ϕ)` variables into Cartesian coordinates.

`soln` has one column per time sample and either
- 2 rows (`χ`, `ϕ`): the constants `(p, M, e)` must then be passed as `model_params`
  (only `p` and `e` are used), or
- 4 rows (`χ`, `ϕ`, `p`, `e`): the constants are read per sample from `soln`.

The radius is `r = p / (1 + e cos χ)`, giving `x = r cos ϕ` and `y = r sin ϕ`.

Returns a `2 × size(soln, 2)` matrix whose rows are `x` and `y`. Throws an
`AssertionError` for any other row count, or when 2 rows are given without a
3-element `model_params`.
"""
function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"

    # Views avoid copying the rows of `soln`.
    χ = @view soln[1, :]
    ϕ = @view soln[2, :]
    if nrows == 2
        # The orbit constants are not part of the state, so they must be supplied
        # explicitly; check for `nothing` first so the caller gets the intended message.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = @view soln[3, :]
        e = @view soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

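This function computes the first time derivative of a uniformly sampled signal. It uses second-order central differences at interior points and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0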

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` (spacing `dt`).

Interior points use the second-order central difference `(v[i+1] - v[i-1]) / 2`. The two
endpoints use second-order one-sided stencils, so the result has the same length as `v`.
Requires `length(v) >= 3`.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    centre = @views (v[3:end] .- v[1:(end - 2)]) ./ 2
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, centre, right_edge) ./ dt
end
d_dt (generic function with 1 method)

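This function computes the second time derivative of a uniformly sampled signal. It uses the standard second-order central stencil at interior points and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0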

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` (spacing `dt`).

Interior points use the central stencil `v[i-1] - 2v[i] + v[i+1]`. The endpoints use the
second-order one-sided stencil `2v₀ - 5v₁ + 4v₂ - v₃` (mirrored at the right end). The
result has the same length as `v`. Requires `length(v) >= 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    body = @views v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, body, tail) ./ dt^2
end
d2_dt2 (generic function with 1 method)

Now we define functions to compute the trace-free moment tensor from the orbit and, from its second time derivative, the $(2,2)$-mode gravitational-wave strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free mass-moment tensor along `orbit`, scaled by `mass`.

`orbit` is a matrix whose first two rows are the `x` and `y` coordinates. `component` is
indexed as `(i, j)`:
- `(1, 1)` gives `Ixx - tr / 3`,
- `(2, 2)` gives `Iyy - tr / 3`,
- any other pair falls through to the off-diagonal `Ixy`,

where `Ixx = x²`, `Iyy = y²`, `Ixy = x y` and `tr = Ixx + Iyy`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    i, j = component[1], component[2]

    Ixx = xs .^ 2
    Iyy = ys .^ 2
    trace_xy = Ixx .+ Iyy

    moment = if (i, j) == (1, 1)
        Ixx .- trace_xy ./ 3
    elseif (i, j) == (2, 2)
        Iyy .- trace_xy ./ 3
    else
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the `component` entry of
the trace-free moment tensor of `orbit` built by `orbit2tensor`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole components for a single body of `mass`
moving along `orbit`, each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    qxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    qxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    qyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    # Order matters to callers: (xx, xy, yy).
    return qxx, qxy, qyy
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain for a single unit-mass body moving along `orbit`, sampled
with spacing `dt`.

Returns `(c * h₊, -c * hₓ)` where `h₊ = h11 - h22`, `hₓ = 2 h12` (components from
`h_22_quadrupole`) and `c = sqrt(π / 5)` in the precision `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation sqrt(pi/5); without the square root the waveform amplitude
    # is scaled by the wrong constant.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the summed quadrupole components `(h11, h12, h22)` of two bodies, each obtained
from `h_22_quadrupole` with its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    (a11, a12, a22) = h_22_quadrupole(dt, orbit1, mass1)
    (b11, b12, b22) = h_22_quadrupole(dt, orbit2, mass2)
    return a11 + b11, a12 + b12, a22 + b22
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain from the orbits of black hole 1 (`mass1`) and black hole 2
(`mass2`), sampled with spacing `dt`.

The masses must sum to one (checked to within `1e-12`, else `AssertionError`). Returns
`(c * h₊, -c * hₓ)` where `h₊ = h11 - h22`, `hₓ = 2 h12` (summed components from
`h_22_quadrupole_two_body`) and `c = sqrt(π / 5)` in the precision `T` of `dt`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation sqrt(pi/5), identical to the one-body case.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` tuple from an orbit solution `soln` sampled with
spacing `dt`.

`mass_ratio = m₁ / m₂` must lie in `[0, 1]` (`AssertionError` otherwise).
- `mass_ratio == 0` is treated as a test particle: the one-body strain of the orbit is used.
- Otherwise the masses are normalised so that `m₁ + m₂ == 1`, and the relative orbit is
  split into two body paths.

`model_params` is forwarded to `soln2orbit`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # m₂ = 1 / (1 + q), m₁ = q m₂  =>  m₁ + m₂ = 1 and m₁ / m₂ = q.
        m2 = inv(T(1) + mass_ratio)
        m1 = mass_ratio * m2

        orbit1, orbit2 = one2two(orbit, m1, m2)
        # Pass the locals just computed (the previous code referenced undefined names).
        waveform = h_22_strain_two_body(dt, orbit1, m1, orbit2, m2)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$u[1] = \chi$, $u[2] = \phi$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODEs for state `u = [χ, ϕ]` with constants
`p`, `M` and `e`; `t` is unused.

Returns `[dχ/dt, dϕ/dt]`. Both rates share the factor
`(p - 2 - 2e cos χ)(1 + e cos χ)² / sqrt((p - 2)² - 4e²)`. `dχ/dt` carries an extra
`sqrt(p - 6 - 2e cos χ) / (M p²)` and `dϕ/dt` an extra `1 / (M p^(3/2))`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecc_cos = e * cos(χ)

    shared = (p - 2 - 2 * ecc_cos) * (1 + ecc_cos)^2
    radical = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared * sqrt(p - 6 - 2 * ecc_cos) / (M * p^2 * radical)
    dϕ = shared / (M * p^(3 / 2) * radical)
    return [dχ, dϕ]
end

# Mass ratio m₁/m₂; 0.0 selects the one-body (test-particle) strain in `compute_waveform`.
mass_ratio = 0.0         # test particle
# Initial state [χ₀, ϕ₀] of the orbit ODEs.
u0 = Float64[π, 0.0]     # initial conditions
# Number of saved time samples.
datasize = 250
# NOTE(review): tspan is Float32 while u0 and ode_model_params are Float64 — confirm the
# mixed precision is intended.
tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
# Spacing of the saved samples; passed to `compute_waveform` for the finite-difference
# derivatives.
dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step size (the solves below run with `adaptive=false`).
dt = 100.0
# Orbit constants (p, M, e), used by the true model and read as a global by `ODE_model`.
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the true relativistic orbit with fixed-step RK4, saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a (plus, cross) pair; the first element is the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line plus markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    # NOTE(review): `markershape`/`markeralpha` are Plots.jl attribute names; Makie's marker
    # attribute is `marker` — confirm these keywords take effect with the installed Makie.
    s = scatter!(ax, tsteps, waveform; markershape=:circle,
        markersize=12, markeralpha=0.25, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the orbital variable $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally not recommended, but in case you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and instead use `truncated_normal` from WeightInitializers.jl with a very small standard deviation ($10^{-4}$):

julia
# Correction network: `cos` is applied elementwise to the scalar input, followed by two
# 32-unit hidden layers with `cos` activations and a 2-unit output (one per ODE rate).
# Weights are drawn from a truncated normal with std 1e-4 (biases start at zero), so the
# initial output is close to zero and `1 .+ nn(...)` in `ODE_model` starts near 1.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# NOTE(review): `Xoshiro()` is unseeded, so the initial parameters differ between runs;
# pass a seed (e.g. `Xoshiro(0)`) if reproducibility is needed.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00023428997; -8.684048f-5; -4.0255204f-6; 6.54771f-5; 6.415647f-5; -2.7389971f-5; -5.474401f-5; 0.00010381014; -4.8548434f-5; 0.00020308526; -0.00019101764; 8.481202f-6; 7.355556f-5; -4.6032477f-5; 2.7787542f-6; 0.0001706191; -1.662709f-5; -0.00018833872; 6.941935f-5; 0.0001393228; -5.3570046f-5; 4.333676f-5; -5.073853f-5; 8.378465f-5; 6.311032f-5; 1.1791431f-5; 7.488557f-5; 1.2126833f-5; 6.059752f-5; 4.7191766f-5; -0.00012270494; -9.204919f-5;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[6.93335f-5 -2.8259301f-5 -4.3718514f-6 1.2087882f-5 1.2790806f-5 4.1278905f-5 -6.25385f-5 4.9532886f-5 -4.9188522f-5 -0.00016169736 0.00011241832 -0.00016420771 -6.211246f-5 -0.00014722352 -0.000102415 -4.0479795f-6 -0.000100073565 3.2051797f-5 -1.8884185f-6 7.028216f-5 -0.00010190853 -6.48466f-5 9.341909f-5 0.00012587551 -0.00018010645 -0.00015517046 5.9764443f-5 -8.7236425f-5 -0.00013064282 0.00020178649 -0.000104049905 -0.00010623942; -0.00019221044 5.4364085f-5 6.813449f-5 -7.4336895f-5 1.2572007f-6 0.0001748893 8.945907f-5 -6.748034f-5 3.164656f-5 -6.1754596f-5 9.4020426f-5 -9.157712f-5 -6.442088f-5 2.5986996f-5 -6.1389444f-5 0.00016375378 -9.860361f-6 9.062229f-5 0.00011687793 -0.00017562903 1.5843998f-5 -3.684154f-5 -8.151624f-5 1.1671711f-5 0.00010124109 0.00011909989 -7.5952585f-6 -0.00012696245 -3.0207932f-5 0.0001471766 5.5959736f-6 -8.401232f-5; 2.0168849f-5 4.2285497f-5 -4.9132817f-5 -1.08513f-5 0.00015082704 5.2320822f-5 -0.00023840588 -2.9537963f-5 4.3317774f-5 -6.804775f-5 6.293456f-5 0.00013149001 7.382045f-5 -9.062325f-6 -0.00014242924 4.266049f-5 -2.473969f-5 0.00016665127 -7.0030495f-5 4.951238f-5 -8.70818f-5 -0.00013093735 0.0001208161 1.1110255f-5 5.870077f-6 0.00010138293 -3.2118246f-6 0.00016883429 -4.7281268f-5 -0.00019771006 
0.00010620622 -1.705842f-5; 4.0656258f-5 8.265579f-5 -0.00013649129 -2.3457318f-5 0.00010383852 8.459596f-5 -0.00011749194 -0.00014663543 -0.0001960738 -5.5585042f-5 -0.0001268352 -0.00011515475 0.00021179496 -1.7612923f-5 6.494667f-5 -0.00012559911 6.473921f-5 -0.00016921679 1.6622f-5 8.630804f-5 7.198261f-6 6.9831265f-5 7.498943f-5 9.545437f-5 0.00011362288 -7.1051f-5 -9.413723f-5 -0.00014324262 -0.000115353556 5.0161965f-5 -0.00013650699 -0.00019545533; 0.00015553106 -0.00010599973 -6.199864f-5 -2.6860873f-6 -7.593437f-5 0.00011171458 5.9082875f-5 -2.8083514f-5 -0.000102789985 -8.966962f-5 -0.0001025491 0.000112947695 6.49046f-5 7.701169f-5 4.9815453f-6 7.139885f-5 -6.2451596f-5 -0.00019520447 0.00010964626 -0.00021804302 -4.1063202f-5 4.4094336f-6 -5.0409044f-6 -0.00016371181 -0.00011560487 -5.692295f-5 -1.7540546f-5 -0.000112492606 -3.183555f-5 4.9679475f-6 3.8568345f-5 -9.906378f-5; 4.5333563f-5 -7.409835f-5 0.00012180001 8.39162f-5 -1.2581061f-5 6.9666195f-5 -4.563049f-5 3.1231513f-5 -6.2666215f-5 -6.745029f-5 1.3924931f-5 -8.916947f-5 9.610091f-5 3.7270744f-5 -9.981516f-5 -0.0001164734 3.5014036f-5 0.0001444317 -4.085073f-5 -5.040203f-5 -0.0001374248 5.0660125f-5 0.00012238167 4.860329f-5 -0.00016008494 -8.924109f-5 0.0001499602 -0.00017366419 -2.3955545f-5 -0.00018774581 -4.8112423f-5 1.1697931f-5; 0.000117442854 -0.0001539785 0.000110427405 9.489322f-5 6.6295666f-5 -3.7373244f-5 -5.780035f-5 -5.7262223f-5 0.00012026026 0.00020116006 -0.00018465439 4.5645847f-5 -0.00010803395 0.00012548818 0.0002774758 0.00020159109 -5.3965152f-5 -1.522673f-5 6.903063f-5 8.950893f-5 4.9190243f-5 -0.0001982525 8.658844f-5 -2.4200503f-5 -0.00019207531 7.875197f-5 -3.857776f-5 -0.00010856375 0.0003154384 3.009129f-5 7.757821f-5 7.5701755f-6; -5.1373368f-6 -4.5006975f-5 -0.00011952661 6.483447f-5 -0.00011211344 -6.8628644f-5 -0.00030121836 -7.4975615f-5 4.8555008f-5 -6.576244f-5 -8.112424f-7 7.9809416f-7 0.00021076875 0.00025370353 -6.747062f-5 0.000104288054 -0.0001330641 
5.269299f-7 -6.659119f-5 0.0002268003 0.00016957258 5.5934845f-5 0.00014086493 8.1018115f-5 3.504692f-5 -6.804393f-5 -1.4535145f-5 -0.00016971912 -0.00017717623 -2.4037815f-6 -3.6534788f-5 -0.00017220768; -4.0088365f-5 2.2576285f-5 -0.00014946608 0.00017981495 -5.5764867f-5 7.7901415f-5 2.018736f-5 0.00011442037 0.00011676731 -0.00014905383 0.0001827995 0.00014134498 -7.51587f-5 -5.4459128f-5 0.00013919703 -2.2351905f-6 3.8121372f-5 4.8703423f-5 -4.268227f-5 -7.916436f-5 -7.509714f-5 -1.7292563f-5 -0.00011611051 -3.467507f-5 -7.9898164f-5 5.3810694f-5 -7.32715f-6 -7.342673f-5 2.1173226f-5 -2.7538557f-5 0.00012367067 7.8096746f-5; 0.00012244977 -0.00012651413 -0.000105239415 -6.914186f-5 0.000119714095 2.0306434f-5 -0.00019741221 5.8113797f-6 -7.004156f-6 -4.1955274f-5 5.5510813f-5 2.001571f-5 3.737427f-5 -0.000120754776 -0.00010930518 -7.2193696f-5 3.5429644f-6 5.732171f-5 -5.2554373f-5 0.00015531399 0.00017213526 -7.1459166f-5 -7.907205f-5 0.00015182521 0.00011143393 -6.617704f-5 -2.9541208f-5 0.00013767075 0.00010330746 -2.2024558f-5 -6.533371f-5 -5.3928625f-6; 5.348635f-5 -1.6156211f-5 -0.00016110331 9.911393f-6 5.9455324f-5 -5.956128f-5 -0.00012497106 0.00013109288 -0.00013277419 -2.1312426f-5 0.00016364503 -0.00011309065 6.1213555f-5 -6.307199f-6 2.9900315f-5 -3.7051326f-5 5.3767908f-5 7.1247654f-5 1.2117627f-6 -0.00022506645 -4.7030935f-5 -4.1147174f-5 1.7486403f-6 6.807168f-5 3.7923233f-5 -5.660332f-5 0.0001978147 0.00014059289 6.852399f-5 0.00012977919 6.5685144f-6 -6.708666f-5; 1.2822967f-5 -6.0555194f-5 -0.00010709167 -0.00016187107 -6.2703984f-5 0.00014187321 -0.0001802192 -0.00017303046 -1.6721533f-5 7.83033f-5 4.500443f-5 0.0001168085 -2.24445f-5 -0.00018828029 5.810844f-5 3.460103f-5 7.893916f-5 3.6554306f-5 -1.6298899f-6 2.2341228f-5 -0.00010465144 -0.00020032916 8.5593696f-5 1.7055525f-5 -1.0325211f-5 0.00020142904 -8.8050714f-5 0.0001336565 -0.00012591254 -0.00018896715 -2.2715702f-5 0.0001053404; 0.00018769769 -5.7848756f-5 1.7897784f-5 
8.9528716f-5 -5.4972537f-5 -0.0001534091 0.00012186858 -3.912847f-5 -6.117687f-5 2.9148794f-5 -1.338809f-5 0.00026159064 -4.3187825f-5 -0.000102122496 0.000108954824 -5.3548133f-6 -4.079067f-6 -3.891172f-5 9.789185f-6 0.00012814111 -0.00013263858 -8.640494f-5 -5.9314887f-5 -0.00012747929 -0.00019097015 6.9157606f-5 9.939574f-5 -2.145523f-5 -8.660185f-5 6.0501217f-5 6.4834916f-5 0.00019948947; 2.6010812f-5 -8.01058f-5 -3.4125645f-5 8.817442f-5 -3.717172f-5 0.0001283602 9.238838f-5 7.4219715f-5 9.470325f-5 2.2261365f-5 -0.00014238802 2.4803008f-5 0.00011991973 3.9184606f-6 1.8409894f-5 3.0330353f-5 -4.0784933f-5 -0.00014598902 -7.9666504f-5 -0.00013549428 6.3655098f-6 -7.489001f-5 -0.00012596078 3.5003042f-5 -1.2829264f-5 2.9132301f-5 9.818763f-5 0.00012626195 -5.2304684f-5 0.00013699192 -3.1170166f-5 -1.9446326f-5; 6.4800166f-5 3.5773664f-5 -6.7415196f-5 -6.605483f-6 -6.838318f-5 9.116646f-5 0.00014036418 4.9606864f-5 -0.00017349845 1.7967746f-5 7.359749f-5 -1.4848573f-5 6.9998365f-5 -6.650569f-5 -0.00012796788 0.00017906626 0.00020262408 -2.8910505f-5 -1.2437257f-5 -0.0001125372 -3.014834f-5 9.870737f-5 -5.92879f-5 9.538057f-6 0.00018714704 -7.239343f-5 -0.00017413363 -5.9509657f-5 6.002056f-6 -9.623122f-5 -1.7231661f-5 -0.00016628364; -0.00010035637 0.00013958734 -0.0001783269 -0.00012100978 0.00017382295 7.326824f-5 -0.00017349153 6.532829f-5 7.0404094f-5 1.7303346f-5 -0.00016059025 8.513046f-5 -0.00017532932 -1.3121186f-5 -9.755843f-6 -0.00019052887 -2.2395458f-5 -1.4971898f-5 2.294556f-5 7.3762065f-5 8.951538f-5 -2.7929384f-6 0.00014612511 -7.829575f-5 3.0058807f-5 -2.5269306f-5 -7.8209145f-5 5.611693f-5 -0.00014208518 -2.6588947f-5 1.995017f-5 -2.1708476f-5; -5.2535557f-5 9.461455f-5 8.808028f-5 0.0001159127 0.00018574178 -0.00012173049 -0.00012588759 3.1211246f-5 3.998778f-5 0.000136192 9.88793f-5 -1.11749605f-5 -6.599249f-5 8.0051526f-5 1.0382291f-5 -5.222838f-5 0.00011758656 -0.000115685056 3.272994f-6 7.539913f-5 -3.6131954f-5 -8.759615f-5 6.47246f-5 
-9.366825f-5 -9.6453274f-5 4.5589048f-5 3.8763803f-5 1.6735352f-5 -4.9514172f-5 6.726272f-5 9.894188f-5 8.496692f-5; -4.986115f-5 -3.3543995f-5 9.1382106f-5 -4.9094583f-6 0.00018497008 -3.3190052f-5 -5.370984f-5 0.00010369876 2.5382236f-5 -4.757402f-5 -4.2074946f-5 -3.240842f-5 -2.906429f-6 9.237654f-6 -0.00022805954 3.2784214f-5 -9.384264f-5 4.550893f-5 0.000118474396 -5.9756334f-5 2.8129489f-5 2.892558f-6 6.249213f-5 -0.00010260177 -0.00017726395 6.191182f-5 -0.00014433048 -8.74536f-6 -3.579979f-5 7.3822803f-6 -0.000111779715 -6.620172f-5; -0.0001199331 9.360696f-5 -4.6767782f-5 2.6145415f-5 9.059882f-5 -2.2083863f-5 -1.9870771f-5 8.354823f-5 -0.00017450865 0.00012520942 -0.00012061897 -6.1707055f-5 -0.00017167871 0.00020953128 0.00011374933 -6.1129995f-6 -8.032036f-6 -0.00021742695 -9.396456f-5 -3.1158127f-6 -3.9793304f-5 5.439552f-5 7.053785f-5 -4.0675553f-5 -5.4497963f-5 0.00011538134 6.805109f-5 -5.2184667f-5 0.00017975733 -5.2918953f-5 0.00014115646 -3.6905156f-5; -7.326965f-5 -3.5871875f-5 -9.010265f-5 -9.718869f-5 7.1585535f-5 -7.877674f-6 3.059647f-5 -5.8541677f-6 -2.7653037f-5 -2.4688317f-5 1.3117494f-5 -0.0001066594 -7.981335f-5 -6.553137f-6 -4.787454f-5 -8.195723f-5 -4.519638f-5 -1.2516173f-5 2.2863365f-5 5.852418f-5 0.000109988556 8.9883564f-5 -0.00014783039 0.00022573174 -0.00015021885 6.457383f-5 -2.024331f-5 -6.796181f-5 -7.71419f-6 7.950469f-6 -0.0001453032 -0.00022172176; -4.74446f-5 -0.00020778227 0.0001052762 -0.00010109279 6.7789288f-6 -3.8156657f-5 -0.00020990093 -7.314523f-5 -8.0334896f-5 -5.930037f-5 6.5404434f-5 2.1062768f-5 0.00014021437 3.460533f-5 -0.000105929714 5.2461914f-6 -1.0135821f-5 -0.0002402825 7.4796604f-5 5.0998377f-5 -9.8891316f-5 -2.6579954f-5 -0.00014874242 -4.0376308f-5 -6.5672066f-5 -0.00010701795 9.5799995f-5 -2.5632711f-5 2.2789189f-5 -7.216266f-5 -6.686478f-5 0.00015318823; -6.494366f-5 -0.00013367632 7.0767026f-5 -0.00011029089 1.0926776f-6 -1.5869886f-5 -5.3721662f-5 7.793729f-7 -4.851346f-5 3.6120266f-5 -9.74303f-5 
-0.00016123072 -2.3119126f-5 -3.8194084f-5 -6.594418f-5 -0.00013025354 -0.00013236843 3.079631f-5 -0.00013379796 3.4496232f-5 -5.4113116f-5 1.6564318f-5 -3.9254875f-5 -8.467732f-5 -5.7130957f-5 -0.0001645787 6.8971945f-5 -5.5042445f-5 -6.9398135f-5 -7.7660883f-7 5.194813f-5 0.00012876176; 4.9289803f-5 0.00010816311 4.846758f-5 0.00012471167 -5.6502533f-5 -9.272778f-5 -0.0001162367 -4.2943055f-5 -5.69652f-6 5.1191495f-5 -4.1729792f-5 -7.864291f-6 5.5073133f-6 -1.0749004f-5 2.1900049f-5 -6.112678f-5 -7.522458f-5 3.3274508f-5 -5.677147f-5 -0.000104613464 3.3802302f-5 6.489096f-5 -3.7505837f-5 -5.0454373f-6 -2.2892627f-5 -5.561064f-5 -6.816983f-5 5.861815f-5 0.00013188767 3.3545384f-5 4.030673f-5 0.0001473911; -0.000101338184 -3.3886314f-5 -2.7018737f-7 -0.00013459526 9.907725f-5 0.00027799967 -0.00019834834 3.5400273f-5 -1.1707409f-5 -0.00010105312 6.717375f-5 -0.00012531577 0.000108378714 0.0002952709 0.00020270947 -7.4494565f-5 2.121471f-5 3.899747f-5 -9.012236f-5 -6.014267f-5 -3.948219f-5 -0.00017769213 4.323426f-5 2.544018f-5 4.137327f-5 -8.664186f-5 -0.00010166301 6.0757815f-5 -5.338988f-5 -0.0001935943 5.171284f-5 -8.8761546f-5; -0.00013193821 -5.0528728f-5 -2.5150626f-5 -7.579001f-6 -0.0001494822 5.086211f-5 -1.2808537f-6 8.6428794f-5 -9.795758f-5 3.22309f-5 1.4057839f-6 -1.11671725f-5 -0.00015497718 -0.00019547803 4.4922508f-5 -0.0001464736 2.8142204f-5 -4.768908f-5 -6.607789f-6 8.701897f-6 -0.0001062539 0.00012417886 0.00017022823 -5.715432f-5 -4.8212637f-6 6.06542f-5 8.864419f-6 -0.00021589323 -0.00015776794 -1.0723759f-5 -6.420938f-5 0.000121206285; -0.00012318601 -2.3083112f-5 3.5363406f-5 -0.00012994303 5.6667533f-5 -0.00011043625 0.00016666073 -1.2469825f-5 0.00023012848 5.0100298f-5 -7.037115f-5 -9.69283f-5 9.7863885f-5 -7.077854f-5 2.2207014f-5 -7.530199f-5 -0.0001462962 7.594791f-5 5.3519987f-5 8.5784704f-5 1.6303344f-5 -2.448157f-5 -0.00021097953 -1.2569646f-5 -3.8745475f-5 -2.1772421f-5 -0.000104106526 -5.736898f-5 6.240032f-5 0.00014700898 
0.00021905667 -3.3053432f-6; -3.9699702f-5 0.00014629257 2.9613595f-6 -7.845332f-6 -6.0175094f-5 -2.1703412f-5 5.9505128f-5 1.512364f-5 2.301177f-5 4.372031f-6 -0.00017682044 0.00013044638 0.00011109605 -0.00016561737 -5.3967695f-5 -9.039127f-5 -0.0001207593 0.00013565009 -1.9963612f-5 -3.0316914f-5 0.00014093253 0.00019957685 6.7234585f-5 4.5658067f-5 0.00017800937 -0.00019540812 0.0001532971 0.0001457061 -0.000102385064 0.00011029503 1.772489f-5 3.757657f-5; 0.00029580077 -2.23942f-6 -0.000104349216 2.0657353f-5 3.5959143f-5 -4.006515f-5 0.00028961973 1.9854604f-5 4.485329f-5 0.00016119122 -7.860031f-5 -0.00021778041 1.5497529f-6 -7.741615f-5 0.0001848435 -0.000106691405 3.3360582f-5 -0.00016602782 6.976328f-5 0.00014438029 0.00021777002 -1.9237212f-5 0.00020252886 -0.00012971004 -0.0001376687 0.00015955005 -1.1175171f-5 5.260251f-5 -8.933945f-6 -5.3035976f-5 6.5663585f-6 -2.655403f-5; 0.00016891549 -0.00012631103 0.00012700692 -4.7709902f-5 -5.6994748f-5 0.00015103392 3.4949106f-5 0.0002609988 -0.00016564893 -2.4667463f-5 2.2431017f-5 -0.0001952503 -1.68762f-5 5.1840812f-5 -0.00027584154 7.325276f-5 0.000111886126 -5.0373892f-5 -8.001074f-5 -4.4050495f-5 0.00016264377 -0.00019418943 -0.0001907584 3.0831547f-5 -4.397562f-5 2.943764f-5 -0.00010972604 -1.16889805f-5 8.9852845f-5 -2.3312445f-5 -4.1216746f-5 4.890717f-5; 0.00012636406 -6.7772016f-5 0.000101423655 -0.00014196151 6.428613f-5 -9.678904f-5 0.00015335537 -4.845562f-5 4.5940724f-5 -7.691993f-5 0.00011018601 0.00019676007 1.6643253f-5 9.1436095f-6 -4.233316f-5 -1.1008907f-5 0.00027906353 3.4919263f-5 5.561889f-5 -6.2709514f-5 7.127844f-6 7.4158335f-5 -0.0001530208 4.4403372f-5 -6.913907f-5 0.00012347588 -0.00011977481 -5.6780984f-5 7.1433635f-5 0.00014121292 2.16292f-6 5.723077f-5; -6.1018847f-5 5.9994432f-5 -2.1025055f-5 -7.374212f-5 2.011736f-5 0.00024003936 -3.641935f-5 3.6177316f-5 -0.00021528924 1.3573869f-5 -0.0001907653 9.778103f-5 -8.2250255f-5 6.0822395f-6 2.1356262f-5 0.000112431415 -0.00019132603 
2.8674324f-6 -2.9779487f-5 6.30095f-5 0.00010912775 3.2433392f-5 -0.00017452707 -8.052856f-5 -3.1560005f-5 1.7208082f-5 -3.2347773f-5 -6.4674505f-6 -0.00014404189 -0.00014373382 8.27096f-5 7.270639f-5; -1.5961157f-5 -4.136214f-5 -7.2368166f-5 3.9405626f-5 -4.5812008f-6 -3.0803265f-5 6.888066f-5 -0.0003209165 9.76202f-5 -0.00019528256 -6.833917f-5 -8.949024f-5 7.3954514f-5 -5.7369067f-5 -2.9973551f-5 0.00011524067 -0.00016713163 0.00012244494 -6.6678775f-5 5.645562f-5 -8.138363f-5 0.00013005176 -4.762211f-5 -0.00021243331 -1.5640784f-5 -8.5270694f-5 7.988394f-5 -3.05983f-5 -5.9561902f-5 1.9717023f-5 -7.273644f-5 3.023156f-5], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[1.1417459f-6 -9.2154354f-5 6.198802f-6 7.513918f-5 -3.774318f-5 -0.00016029844 -0.00016441477 -0.00011788595 -0.000112841786 -7.3770934f-5 0.00015071794 1.6648617f-5 6.146369f-5 -3.7949227f-5 -7.868718f-5 -7.4014235f-5 -8.755276f-5 -8.561004f-8 -8.652212f-5 5.222511f-5 6.685596f-5 4.5325683f-5 -0.00021458087 -1.4983789f-5 0.00010176535 -2.1804826f-5 -0.00015040278 -2.6310427f-5 -0.00013933021 7.0674185f-5 5.6071694f-5 4.886804f-5; 0.00016711895 -0.00021441269 2.7099265f-6 -3.4005836f-5 -7.7451805f-6 7.555978f-5 0.00014134546 0.00016764224 -0.00010708229 -0.00010837667 0.00014525771 1.9983381f-5 5.928817f-5 -0.00010856197 4.6980997f-5 6.140471f-5 -6.763431f-5 1.3159464f-5 -8.220747f-5 -0.00022544306 6.247963f-5 -3.901131f-6 0.00012619093 -5.539247f-5 0.00015212843 -3.432625f-5 4.5934714f-5 -3.905411f-5 -5.4318192f-5 0.00013103426 5.619364f-7 0.00013232879], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using `Float32`; however, in this case we need `Float64`.

julia
# Lux initialises parameters as Float32; convert them to a Float64 ComponentArray so the
# ODE solve and the optimizer work in double precision.
const params = ComponentArray{Float64}(ps)

# Bundle the network with its state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, st)
Lux.StatefulLuxLayer{true, Lux.Chain{@NamedTuple{layer_1::Lux.WrappedFunction{Base.Fix1{typeof(broadcast), typeof(cos)}}, layer_2::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_4::Lux.Dense{true, typeof(identity), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}, layer_4::@NamedTuple{}}}(Chain(), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()), nothing)

Now we define a system of ODEs that describes the motion of a point-like particle using Newtonian physics, with each rate multiplied by a learned correction $1 + \mathrm{NN}(\chi)$. It uses

$u[1] = \chi$, $u[2] = \phi$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for state `u = [χ, ϕ]`.

Both rates equal the Newtonian factor `(1 + e cos χ)² / (M p^(3/2))`, multiplied
component-wise by the learned correction `1 .+ nn_model([χ], nn_params)`. The constants
`(p, M, e)` are read from the global `ode_model_params`; `t` is unused. Returns
`[dχ/dt, dϕ/dt]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example `st` (stored inside `nn_model`) is an empty NamedTuple, so it can
    # safely be ignored; in general `st` holds the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Neural ODE problem; simulate it once with the untrained (Float64) parameters.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained-network prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    # NOTE(review): `markershape`/`markeralpha` are Plots.jl attribute names; Makie's marker
    # attribute is `marker` — confirm these keywords take effect with the installed Makie.
    s1 = scatter!(ax, tsteps, waveform; markershape=:circle, markersize=12,
        markeralpha=0.25, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn; markershape=:circle,
        markersize=12, markeralpha=0.25, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
"""
    loss(θ)

Solve the neural ODE `prob_nn` with network parameters `θ` (fixed-step RK4 saved at
`tsteps`) and compare its first waveform component with the global target `waveform`.

Returns `(sum of squared errors, predicted waveform)`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    pred = Array(sol)
    h_pred = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    sse = sum(abs2, waveform .- h_pred)
    return sse, h_pred
end
loss (generic function with 1 method)

Warm up the loss function

julia
# Warm-up call: evaluate the loss once with the untrained parameters before training.
loss(params)
(0.17145293657477417, [-0.02425971839980482, -0.02347514969905942, -0.022690580998313763, -0.021364959632067925, -0.019470624935930957, -0.01696743173988719, -0.013802685161898998, -0.009908193856605896, -0.005202551497418279, 0.0004091927695814442, 0.007026321025980222, 0.014731533496598203, 0.023543479380586537, 0.033288879361909024, 0.04332565947871018, 0.051888781040384635, 0.05470678906402271, 0.042740093771747215, 0.002461953307213939, -0.06574710263685611, -0.11038075246582811, -0.07684644868065406, -0.0070629911224377295, 0.03898548969707725, 0.05454220951441442, 0.05313852438067156, 0.04490879912991775, 0.034782676734843566, 0.024814893727950874, 0.015765187220477168, 0.007844326720486728, 0.0010469759455870064, -0.004711145853658598, -0.009533830854624312, -0.013521324752237508, -0.01675931411041456, -0.01932050277291237, -0.02126169371885027, -0.022626053324683832, -0.02344407818308469, -0.023734629079228613, -0.023505331346567678, -0.02275391637164664, -0.021465892089072455, -0.01961677436013426, -0.01716976122895218, -0.014075076515962927, -0.010269504814147979, -0.0056757435762165905, -0.0002040917980788089, 0.006239947440291093, 0.013739720981990115, 0.022319574044781538, 0.03184280517314327, 0.041763866325643224, 0.050567734645962904, 0.05455586488697105, 0.045721335206860225, 0.010989358507754165, -0.05416562326213272, -0.10791042225539324, -0.08701731656486468, -0.0175140510488666, 0.034368863051251315, 0.05399942084364376, 0.054290855156841296, 0.04650637087762936, 0.036323174436016945, 0.02613589903794762, 0.01684213839986292, 0.008695513396305705, 0.001711966779964922, -0.004202152316405672, -0.009144975822912946, -0.013231052543922539, -0.016543700874363156, -0.01916700326691504, -0.02115534272587337, -0.022560500398936784, -0.023412018606661648, -0.02373388211209378, -0.02353536821700239, -0.022815526588118227, -0.021564461943478608, -0.019760063450602796, -0.017365390688610254, -0.014340055846231188, -0.010617526352711975, 
-0.006132928409975344, -0.0007932572810951811, 0.005483729506829185, 0.012787429376233023, 0.021141400081395843, 0.030444895397524523, 0.04022845719348157, 0.049196595392421716, 0.0541465279714622, 0.048020676853060815, 0.01849198466081878, -0.04250615020860838, -0.10310202074061656, -0.09582469777658002, -0.028736787542361458, 0.02880272081730916, 0.053005086134281704, 0.05531530362038006, 0.0481086248239887, 0.037908769805926945, 0.027510356973732723, 0.01796106825372976, 0.009586209412406954, 0.0024004154921840088, -0.0036691861536389926, -0.008744874039456868, -0.01292777825003056, -0.016323905908192243, -0.019006830415436195, -0.021048024395915698, -0.02249218174233715, -0.023379819400901364, -0.02373248600540321, -0.023564152141612317, -0.022877134820711825, -0.021660899467428295, -0.019899219079396343, -0.017556554186551297, -0.014596063883329466, -0.01095495278609227, -0.0065730696789751885, -0.0013612680380835685, 0.004757498001377214, 0.011871403095158625, 0.020009087576693722, 0.029092672968437573, 0.038725381643868066, 0.0477931088500526, 0.05352416169808578, 0.04972919605466459, 0.02498141401365203, -0.031104228110583054, -0.0962253430479292, -0.10284786236553495, -0.04051461700944619, 0.022212615830933036, 0.051481264483957484, 0.05617385499697044, 0.049704216348910944, 0.0395401157672665, 0.028933113477847563, 0.019132252146573662, 0.010510835088567073, 0.0031216225645659622, -0.0031209902151816396, -0.008323721469108473, -0.012616878880753534, -0.016094333017646352, -0.018845547675665247, -0.020934633666178527, -0.022423639578307383, -0.023346214436996995, -0.023730630992484437, -0.023592839795296784, -0.022936924741829658, -0.021756482853409132, -0.020033078997773714, -0.017745033382887045, -0.014843617489943325, -0.011280517810216295, -0.006998822849953265, -0.0019098319951136547, 0.004060088821293181, 0.010992096439686747, 0.01891797325019186, 0.02778668697660715, 0.03725728058091018, 0.04637102650000353, 0.05273360064606343, 0.05092406082912004, 
0.03051272500295964, -0.020248041133957752, -0.08762376636029623, -0.10774473910920694, -0.05255650383613921, 0.014555868013949834, 0.049331709123440835, 0.056826467140593695, 0.05127840859611898, 0.041210347478256665, 0.030416266315713, 0.020347359469136964, 0.011476805037362377, 0.0038694338088798843, -0.002542007081393614, -0.007891999392570292, -0.012293602156500126, -0.015859250201693566, -0.018673918326248252, -0.020821678335423404, -0.022352694268213145, -0.023311874863639347, -0.02372790510595071, -0.02362091939930721, -0.022995923923929767, -0.02184838024224056, -0.02016881307117635, -0.017925106999348312, -0.015085011324874572, -0.01159679848763753, -0.007409106433961959, -0.002437502106351013, 0.003388021132666341, 0.01014507863550597, 0.017869819605678922, 0.026526073862763025, 0.03582553470360525, 0.04494323239373693, 0.0518070727819305, 0.051682334339722996, 0.035148409346072225, -0.010128587176919033, -0.0777210813833346, -0.11027571029865865, -0.06448589064453623, 0.005804970399881447, 0.046469826792157325, 0.05721862954686421, 0.05281327617646386, 0.042920498091755056, 0.03195029784379894, 0.021615985404564515, 0.012484300745588015, 0.004651732347741648, -0.001943692870586101, -0.007440968679769551, -0.011958038945524753, -0.015614447575927379, -0.018499997686746646, -0.02070378595951876, -0.02227982650926731, -0.02327634470979769, -0.023724871850278748, -0.023648334026430115, -0.02305407116264455, -0.021940048515424384, -0.020299121096561682, -0.01810217265633723, -0.015319646772978254, -0.011903391509769243, -0.007805785030902, -0.00370817855203478])

Now let us define a callback function to store the loss over time

julia
# Training history: the callback below appends one loss value per invocation.
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Optimizer callback: append the current loss `l` to `losses` and print a progress line
whose iteration number is the number of losses recorded so far. `θ` and
`pred_waveform` are accepted to match the callback signature but are not used.
Always returns `false` so the optimizer is never asked to stop early.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf("Training %10s Iteration: %5d %10s Loss: %.10f\n", "", iteration, "", l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It works well here because the Newtonian model already provides a very good initial guess.

julia
# Gradients of the loss come from Zygote reverse-mode AD. The objective ignores the
# problem's hyperparameter argument `p` and depends only on the trainable parameters.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search; the small first step
# presumably reflects that the Newtonian model starts close to a good solution.
bfgs_optimizer = BFGS(;
    initial_stepnorm=0.01,
    linesearch=LineSearches.BackTracking(),
)
res = Optimization.solve(optprob, bfgs_optimizer; callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00023428996792033952; -8.684048225393362e-5; -4.0255204112302425e-6; 6.547709926953261e-5; 6.415646930688764e-5; -2.7389971364800818e-5; -5.474401041272074e-5; 0.0001038101399898761; -4.85484342788303e-5; 0.00020308526291009255; -0.00019101763609787213; 8.481201803064987e-6; 7.355555862890295e-5; -4.6032477257516705e-5; 2.7787541512173035e-6; 0.00017061909602476832; -1.6627089280491422e-5; -0.00018833871581578826; 6.941935134815762e-5; 0.0001393228012603148; -5.357004556565697e-5; 4.333676042736461e-5; -5.0738530262591006e-5; 8.378465281564484e-5; 6.311031756918374e-5; 1.179143055194594e-5; 7.488556730094685e-5; 1.212683309857032e-5; 6.059751831329735e-5; 4.7191766498101334e-5; -0.00012270493607482788; -9.204918751488198e-5;;], bias = [-3.193563022897041e-16; -1.258041331171465e-16; -2.727357530961986e-18; 1.0747361564888837e-16; 2.2348920749521912e-17; 5.050767249254353e-17; -1.23421718831363e-16; 7.622118754317832e-17; -1.082368260748225e-16; 4.0903054374346425e-16; -1.0574502496256282e-16; 1.5286725339200762e-17; -3.8702082839077667e-17; -2.498879527568436e-17; 5.8033007233798516e-18; 1.3810819539343192e-16; -2.983211833823862e-17; -1.41158483881984e-16; 3.3071248618573466e-17; -3.87517711335997e-17; -6.94377152858571e-17; -3.0124129778298234e-17; -1.840825468800088e-17; 2.1101698865507897e-17; 4.6973607715703956e-17; 2.225082422349032e-17; 8.358837988388199e-18; 2.3104277654887337e-17; 1.5985515946087893e-16; 4.839950210066035e-17; -2.1756610705955587e-16; -5.653790382973844e-17;;]), layer_3 = (weight = [6.933019333883542e-5 -2.8262609947022145e-5 -4.3751604317161945e-6 1.2084573183481233e-5 1.2787497082124395e-5 4.127559598777539e-5 -6.254180904584598e-5 4.952957691053176e-5 -4.919183108363963e-5 -0.0001617006694396045 0.00011241500918184635 -0.00016421102033569454 -6.211576534774485e-5 -0.00014722682523648012 -0.00010241830593082106 -4.0512884588887425e-6 -0.0001000768736667461 
3.2048488475979564e-5 -1.8917275092933042e-6 7.02788545204072e-5 -0.00010191184107320992 -6.48499101480913e-5 9.341578127709509e-5 0.00012587219837946657 -0.00018010975897418829 -0.00015517377166184568 5.976113400897631e-5 -8.72397343010173e-5 -0.0001306461256427184 0.00020178318093889646 -0.00010405321360671184 -0.00010624273115176709; -0.0001922099241049108 5.436460330449046e-5 6.813500792551913e-5 -7.43363766808057e-5 1.2577185977692743e-6 0.00017488981527341642 8.945958822798062e-5 -6.747981981542014e-5 3.1647077296582e-5 -6.175407773051974e-5 9.40209442639675e-5 -9.157660545250275e-5 -6.442035967418885e-5 2.598751368415634e-5 -6.138892652168743e-5 0.00016375429683895425 -9.859843455416786e-6 9.062281012366988e-5 0.00011687844514446383 -0.00017562851146729863 1.584451590380711e-5 -3.684102059858755e-5 -8.151572412978773e-5 1.167222896830654e-5 0.00010124160460028859 0.00011910040618451438 -7.5947405341892875e-6 -0.0001269619301886285 -3.0207413626414994e-5 0.0001471771188086839 5.596491494284626e-6 -8.401179959699524e-5; 2.0169650884415528e-5 4.228629920872744e-5 -4.9132015487327366e-5 -1.0850498551361105e-5 0.00015082783800208913 5.2321623730324825e-5 -0.00023840507340027017 -2.953716132141383e-5 4.331857569524107e-5 -6.80469514230924e-5 6.293536141638709e-5 0.0001314908164891795 7.382124902106761e-5 -9.061523382467889e-6 -0.000142428440380524 4.2661291150220426e-5 -2.4738887372574966e-5 0.00016665207083460965 -7.002969352869361e-5 4.951318048880012e-5 -8.708099478498094e-5 -0.00013093654554212431 0.00012081689935768037 1.111105727942326e-5 5.870878954385181e-6 0.00010138373334085647 -3.2110227260896435e-6 0.00016883509094951755 -4.728046617345899e-5 -0.0001977092600321954 0.00010620702385095568 -1.705761890151677e-5; 4.0653292426135544e-5 8.265282279341061e-5 -0.00013649425479356412 -2.346028329169645e-5 0.0001038355583021808 8.459299327101418e-5 -0.00011749490319756718 -0.00014663839489924833 -0.00019607676610589012 -5.55880071898176e-5 
-0.0001268381713726788 -0.00011515771281678435 0.00021179199732601953 -1.761588852799972e-5 6.494370220921722e-5 -0.00012560207348552755 6.473624282976427e-5 -0.00016921975136038452 1.6619034549750086e-5 8.630507702935273e-5 7.195295647633657e-6 6.982829993817923e-5 7.498646639747579e-5 9.545140685974149e-5 0.00011361991685576908 -7.105396459350561e-5 -9.414019247755083e-5 -0.00014324558675993433 -0.00011535652108263086 5.0158999992944316e-5 -0.0001365099563100811 -0.00019545829515676342; 0.00015552796523792242 -0.00010600282419237839 -6.200173245508579e-5 -2.6891812254859386e-6 -7.593746138501856e-5 0.00011171148637311252 5.9079781043128496e-5 -2.808660772335455e-5 -0.00010279307914663226 -8.967271585477455e-5 -0.00010255219401786519 0.00011294460111764841 6.490150392290393e-5 7.700859552949524e-5 4.978451418880074e-6 7.139575447891614e-5 -6.245468992039891e-5 -0.00019520756338950758 0.00010964316445404791 -0.00021804611040042982 -4.10662958316791e-5 4.406339688447257e-6 -5.043998295714659e-6 -0.00016371490508082907 -0.00011560796135801779 -5.692604261332535e-5 -1.7543639639022553e-5 -0.00011249569959238515 -3.1838642623023936e-5 4.9648535635960985e-6 3.856525129175971e-5 -9.906687207750482e-5; 4.53321267518612e-5 -7.409978891204231e-5 0.00012179857410543122 8.391476556013065e-5 -1.2582497697202977e-5 6.966475891700506e-5 -4.563192808249367e-5 3.123007670680092e-5 -6.266765151352987e-5 -6.74517246120742e-5 1.3923494839652756e-5 -8.917090919980322e-5 9.609946980659628e-5 3.7269307063488275e-5 -9.981659561575867e-5 -0.00011647483508979936 3.5012599687851057e-5 0.00014443026564138275 -4.085216598881657e-5 -5.040346834405818e-5 -0.00013742623152624225 5.065868797415482e-5 0.0001223802287090255 4.860185479216132e-5 -0.00016008637205644797 -8.92425264505996e-5 0.0001499587565153537 -0.00017366562271903737 -2.3956981317077994e-5 -0.00018774724915105291 -4.8113859088092566e-5 1.1696494289447254e-5; 0.00011744605463689043 -0.00015397530050019568 0.00011043060540918288 
9.489642135098157e-5 6.629886657695814e-5 -3.737004343771423e-5 -5.7797148920384293e-5 -5.7259022733226267e-5 0.00012026345822026736 0.00020116325932548732 -0.0001846511889109955 4.564904766970337e-5 -0.00010803074955244155 0.00012549137928521072 0.0002774790034070428 0.0002015942870545643 -5.396195164417846e-5 -1.5223530035929025e-5 6.903383356051041e-5 8.951213268093618e-5 4.9193443109958504e-5 -0.00019824929885568258 8.659164333015439e-5 -2.419730234602709e-5 -0.0001920721127046497 7.875516776924968e-5 -3.8574560030853384e-5 -0.0001085605484061128 0.00031544160329734553 3.0094490930184543e-5 7.758141028691393e-5 7.573375760084725e-6; -5.1384000547906e-6 -4.5008037985485546e-5 -0.0001195276719121722 6.48334088433089e-5 -0.00011211450428929908 -6.862970722834493e-5 -0.0003012194220977353 -7.497667785030355e-5 4.855394480871177e-5 -6.576350379729078e-5 -8.123056849528042e-7 7.970308652801367e-7 0.00021076768361269745 0.00025370246920972263 -6.747168036638699e-5 0.00010428699069196503 -0.0001330651585780543 5.258666115746146e-7 -6.65922499214742e-5 0.0002267992375814895 0.00016957151719140977 5.593378126975273e-5 0.00014086386262676827 8.101705215410635e-5 3.504585664957028e-5 -6.804499672255414e-5 -1.4536208353480105e-5 -0.00016972018157360082 -0.00017717729598046094 -2.4048448344067736e-6 -3.6535851093888035e-5 -0.00017220874825254662; -4.008738463377859e-5 2.257726597967856e-5 -0.00014946510194199556 0.00017981592688379278 -5.576388642289269e-5 7.790239621322186e-5 2.0188341721268155e-5 0.00011442135152542226 0.0001167682916894136 -0.00014905284618358901 0.0001828004810413749 0.0001413459599171353 -7.515771861285534e-5 -5.44581467074295e-5 0.000139198009917937 -2.2342096890128855e-6 3.8122353060362574e-5 4.870440395101314e-5 -4.268128898907068e-5 -7.916338066202667e-5 -7.509615673548093e-5 -1.7291582054849574e-5 -0.00011610952927362647 -3.46740903587049e-5 -7.989718281528498e-5 5.3811675086809076e-5 -7.326169026738385e-6 -7.342574966237345e-5 
2.1174206693148746e-5 -2.753757647219577e-5 0.00012367165479012823 7.809772657132976e-5; 0.0001224498838534997 -0.00012651401710468642 -0.00010523929893975419 -6.914174526469138e-5 0.00011971421110205997 2.030654933764326e-5 -0.00019741209769859257 5.81149526328352e-6 -7.004040420046544e-6 -4.1955158743126846e-5 5.551092828386055e-5 2.001582571824276e-5 3.7374385207609156e-5 -0.00012075466009560322 -0.00010930506767477182 -7.219358020234302e-5 3.5430799623516673e-6 5.7321826822539473e-5 -5.255425720247729e-5 0.00015531410382822747 0.00017213537568454352 -7.145905045331735e-5 -7.90719358367408e-5 0.0001518253258079727 0.00011143404764583979 -6.617692350414856e-5 -2.9541092665799087e-5 0.00013767086704023846 0.00010330757424230401 -2.202444273575954e-5 -6.533359636091954e-5 -5.39274686457659e-6; 5.3486992250923234e-5 -1.6155569105953217e-5 -0.00016110266559310962 9.912035161202288e-6 5.9455965960902635e-5 -5.95606380423603e-5 -0.00012497042015214286 0.00013109352031703107 -0.0001327735499993042 -2.1311783789626114e-5 0.0001636456744696799 -0.00011309001100501751 6.1214197480391e-5 -6.306556847751408e-6 2.9900957010493088e-5 -3.70506841912531e-5 5.376854998273065e-5 7.124829600312237e-5 1.2124046767588446e-6 -0.00022506581270621338 -4.703029306741265e-5 -4.114653173504211e-5 1.7492823550736514e-6 6.807231867666438e-5 3.7923875347284175e-5 -5.660267750984766e-5 0.00019781534675954952 0.0001405935344470594 6.852462858175076e-5 0.00012977983161787684 6.56915646201379e-6 -6.708601548355954e-5; 1.2821012605859646e-5 -6.055714817196317e-5 -0.00010709362373748527 -0.00016187302077017184 -6.270593854426812e-5 0.00014187125947498104 -0.00018022115233182448 -0.00017303241889751498 -1.672348713330508e-5 7.830134876712311e-5 4.500247763643759e-5 0.00011680654592052373 -2.244645344037884e-5 -0.00018828224398023978 5.810648457551642e-5 3.4599076971050985e-5 7.893720925491586e-5 3.655235144832825e-5 -1.6318439839138088e-6 2.2339273928064934e-5 -0.00010465339124496579 
-0.00020033111337400587 8.559174188118858e-5 1.705357084572096e-5 -1.0327165509834348e-5 0.00020142708257478008 -8.805266821584974e-5 0.00013365454964511377 -0.00012591449609320108 -0.00018896910893093384 -2.2717655665494394e-5 0.00010533844346352728; 0.00018769829861752725 -5.784814689490905e-5 1.7898393565478754e-5 8.952932498118945e-5 -5.49719281082124e-5 -0.00015340848963773634 0.00012186918851853478 -3.91278610152288e-5 -6.11762608571057e-5 2.9149403275138027e-5 -1.3387481118345836e-5 0.0002615912506769337 -4.318721599360848e-5 -0.00010212188695390562 0.00010895543331912905 -5.354204091646619e-6 -4.07845769245344e-6 -3.891111023790715e-5 9.789794461315962e-6 0.0001281417188181429 -0.00013263797350759486 -8.640433421191608e-5 -5.931427782001768e-5 -0.0001274786792748696 -0.00019096954397923206 6.915821526832132e-5 9.939634568923231e-5 -2.1454619997025475e-5 -8.660124345282908e-5 6.0501826595786544e-5 6.483552519398633e-5 0.00019949007932496278; 2.6011321835233216e-5 -8.010528762345414e-5 -3.412513541509484e-5 8.817492999848752e-5 -3.71712114322946e-5 0.0001283607132897237 9.238888816223343e-5 7.422022537109179e-5 9.470376318101646e-5 2.226187535749447e-5 -0.0001423875066287419 2.4803518328256035e-5 0.0001199202386593894 3.91897066190019e-6 1.841040358310479e-5 3.0330863238911996e-5 -4.0784422879667e-5 -0.00014598850901923816 -7.966599440653974e-5 -0.00013549376961994388 6.366019817207536e-6 -7.488950285581057e-5 -0.00012596027153297454 3.5003552340341565e-5 -1.2828753995861982e-5 2.9132811334122954e-5 9.818813924691459e-5 0.00012626245808103033 -5.230417372179364e-5 0.00013699242918998813 -3.116965598413124e-5 -1.944581634991563e-5; 6.479971778259382e-5 3.577321620337674e-5 -6.741564431781048e-5 -6.605931340748594e-6 -6.838363044284325e-5 9.116601284272881e-5 0.00014036373175127374 4.960641597659385e-5 -0.0001734989026057913 1.7967298098085817e-5 7.359704175411226e-5 -1.4849021519188371e-5 6.999791656067573e-5 -6.650613506412204e-5 -0.00012796833181841406 
0.00017906581631774404 0.00020262363539455357 -2.8910952843432797e-5 -1.2437705501218162e-5 -0.00011253764825727609 -3.014878786547634e-5 9.870692445343287e-5 -5.9288348731063936e-5 9.537609165415653e-6 0.00018714659670745912 -7.239387634530761e-5 -0.00017413407915359383 -5.951010536722938e-5 6.00160790650709e-6 -9.623166699515725e-5 -1.7232109550096767e-5 -0.00016628409213939659; -0.00010035801805175012 0.00013958568907721108 -0.00017832854772540074 -0.00012101143492007665 0.00017382129609497222 7.326658688863513e-5 -0.00017349317912565646 6.532663993450195e-5 7.040244255360078e-5 1.7301694435026497e-5 -0.00016059190219354052 8.512880614172403e-5 -0.0001753309696036761 -1.3122837834529705e-5 -9.757494616210792e-6 -0.00019053051781926867 -2.2397109013113853e-5 -1.4973549503760733e-5 2.2943908526501546e-5 7.376041340787743e-5 8.951372748153172e-5 -2.7945898532237868e-6 0.00014612345713568728 -7.829740187706316e-5 3.0057155719547944e-5 -2.5270957656617076e-5 -7.821079615358228e-5 5.611527759494002e-5 -0.00014208682822616835 -2.65905980891004e-5 1.9948518649037783e-5 -2.1710127647093464e-5; -5.253349440657331e-5 9.461661470804947e-5 8.80823463851881e-5 0.00011591476145650129 0.0001857438438058167 -0.00012172842611807291 -0.00012588552267306526 3.1213308877613334e-5 3.9989841800257545e-5 0.0001361940631361697 9.888136633188038e-5 -1.1172897709632166e-5 -6.599042894567569e-5 8.00535890924094e-5 1.0384354074910917e-5 -5.222631802800042e-5 0.00011758862460947731 -0.00011568299300753869 3.275056817350646e-6 7.540119627473662e-5 -3.612989162110759e-5 -8.759408587838493e-5 6.472666068685226e-5 -9.366618537360185e-5 -9.645121094553841e-5 4.55911104402799e-5 3.876586565848222e-5 1.6737414673395996e-5 -4.9512109351503624e-5 6.726478391311141e-5 9.894393956737673e-5 8.496898594994617e-5; -4.986312628045793e-5 -3.354596914747155e-5 9.138013125023897e-5 -4.911432954633895e-6 0.0001849681037650386 -3.319202655135513e-5 -5.3711813182110835e-5 0.00010369678498804423 
2.538026170199049e-5 -4.7575994488066635e-5 -4.207692031071884e-5 -3.241039589071549e-5 -2.9084036602047593e-6 9.235679369645613e-6 -0.00022806151095452627 3.278223890206872e-5 -9.384461295000169e-5 4.5506953808726195e-5 0.00011847242136112815 -5.97583085824695e-5 2.812751413997539e-5 2.890583443637388e-6 6.2490155739703e-5 -0.0001026037454200016 -0.0001772659262702859 6.1909847188267e-5 -0.00014433245373098665 -8.747334634486034e-6 -3.580176339042976e-5 7.380305630555398e-6 -0.00011178168928623072 -6.620369380712804e-5; -0.00011993300001605673 9.360706456235508e-5 -4.6767680097865734e-5 2.614551697505546e-5 9.059892175017765e-5 -2.2083761061104818e-5 -1.987066941338679e-5 8.354832964713573e-5 -0.00017450854654426772 0.00012520952434741682 -0.00012061886816055712 -6.170695275769646e-5 -0.00017167860648597687 0.00020953138294161215 0.00011374943271974957 -6.112897673767872e-6 -8.031934222991901e-6 -0.00021742684482134144 -9.396445973080675e-5 -3.115710862141244e-6 -3.979320178742608e-5 5.4395622945129024e-5 7.053794973061335e-5 -4.067545166591284e-5 -5.449786119398752e-5 0.00011538144456452613 6.805119476511048e-5 -5.2184565058006774e-5 0.0001797574367886121 -5.291885106245142e-5 0.00014115656079255847 -3.6905054068221844e-5; -7.327260344697436e-5 -3.587482887034416e-5 -9.010560414701594e-5 -9.719164471554989e-5 7.158258132456366e-5 -7.880627424514964e-6 3.059351776080594e-5 -5.857121311613966e-6 -2.7655990822880043e-5 -2.469127092570705e-5 1.3114540263461691e-5 -0.00010666235128831008 -7.981630141095572e-5 -6.556090727302222e-6 -4.7877492295288366e-5 -8.196018414982521e-5 -4.519933235080782e-5 -1.2519126801078897e-5 2.2860411729725104e-5 5.852122604901055e-5 0.0001099856022052132 8.988061064208209e-5 -0.00014783334289624942 0.0002257287856196689 -0.0001502218069503408 6.457087921496598e-5 -2.0246262908967362e-5 -6.796476326364906e-5 -7.7171439328791e-6 7.94751532573801e-6 -0.00014530615543427273 -0.0002217247087968795; -4.744837198301748e-5 -0.00020778604513400965 
0.00010527242855965368 -0.00010109656133357634 6.775155530306639e-6 -3.81604301896382e-5 -0.00020990470212785883 -7.31490042330079e-5 -8.033866898989739e-5 -5.930414473767728e-5 6.540066104717567e-5 2.105899446683687e-5 0.00014021059476869118 3.4601557932146836e-5 -0.00010593348698825368 5.2424182154906105e-6 -1.0139594191461712e-5 -0.0002402862780530948 7.479283093340562e-5 5.099460330310066e-5 -9.889508939016983e-5 -2.6583726789285716e-5 -0.0001487461892924685 -4.038008111313781e-5 -6.567583906365795e-5 -0.00010702172286477646 9.579622136294806e-5 -2.5636484419412515e-5 2.278541549964068e-5 -7.216642981296858e-5 -6.686855041554943e-5 0.00015318445665395201; -6.494831616091211e-5 -0.00013368097718145983 7.076237232150959e-5 -0.0001102955400924462 1.0880238974288036e-6 -1.5874539552427853e-5 -5.372631569433063e-5 7.747191730524403e-7 -4.8518112474522946e-5 3.611561185296762e-5 -9.743495031527821e-5 -0.0001612353779123492 -2.3123779850244977e-5 -3.819873820564879e-5 -6.594883309252113e-5 -0.0001302581920967327 -0.0001323730801445756 3.0791655423466243e-5 -0.00013380261664085787 3.449157809568212e-5 -5.41177694899315e-5 1.655966405134443e-5 -3.925952917005035e-5 -8.468197195131694e-5 -5.7135611153607066e-5 -0.00016458335160085468 6.896729171469232e-5 -5.5047098452164806e-5 -6.940278837962943e-5 -7.812625526476022e-7 5.194347538830631e-5 0.0001287571096113682; 4.9290114164851764e-5 0.00010816342333661521 4.846789093668778e-5 0.00012471198502559373 -5.6502222233878674e-5 -9.272746896507021e-5 -0.00011623638633075928 -4.2942743986108326e-5 -5.696209283390136e-6 5.119180574062064e-5 -4.1729481691924676e-5 -7.863979900623957e-6 5.507624026601301e-6 -1.07486933678347e-5 2.1900359259841213e-5 -6.112647164749073e-5 -7.522427253328062e-5 3.327481832335545e-5 -5.6771159817193004e-5 -0.00010461315318146714 3.380261265071286e-5 6.489127341289259e-5 -3.7505526362413604e-5 -5.045126582748322e-6 -2.2892316648115623e-5 -5.5610328073574056e-5 -6.816951858248416e-5 
5.8618459350934426e-5 0.00013188798553407359 3.35456949493773e-5 4.0307042271402245e-5 0.00014739140684347715; -0.0001013392679875497 -3.388739742389347e-5 -2.712710364814923e-7 -0.00013459634601771414 9.907616958786216e-5 0.00027799858844206033 -0.0001983494216825776 3.5399189525709435e-5 -1.170849236826993e-5 -0.0001010542032441884 6.717266705652385e-5 -0.00012531685161106023 0.00010837763011055527 0.0002952698200691526 0.0002027083905862671 -7.44956482800526e-5 2.121362603917064e-5 3.899638659038232e-5 -9.012344480892207e-5 -6.0143752764481876e-5 -3.948327457932387e-5 -0.0001776932108465549 4.323317670910978e-5 2.543909658685375e-5 4.137218684023654e-5 -8.664294682843337e-5 -0.0001016640958392946 6.075673127030295e-5 -5.339096196861541e-5 -0.00019359538373704655 5.171175552947312e-5 -8.876262973229765e-5; -0.0001319414722292145 -5.05319866551514e-5 -2.515388487333382e-5 -7.5822594244204185e-6 -0.00014948546495540047 5.085885013434697e-5 -1.2841122970402577e-6 8.642553538997447e-5 -9.796083659418098e-5 3.222764016580055e-5 1.4025252749643583e-6 -1.117043111841605e-5 -0.00015498044276861318 -0.00019548128957646164 4.4919249465265795e-5 -0.00014647685648190922 2.8138944888297305e-5 -4.769233683139477e-5 -6.6110477764266695e-6 8.698638428589385e-6 -0.00010625715995226169 0.00012417559728240376 0.00017022497295150388 -5.7157579160012776e-5 -4.824522331514132e-6 6.065094303087859e-5 8.861160584308094e-6 -0.00021589649159247458 -0.00015777119451919596 -1.07270178905008e-5 -6.42126387010211e-5 0.00012120302662700893; -0.00012318605449259965 -2.308315860724366e-5 3.5363359261771455e-5 -0.00012994307625782505 5.666748684738003e-5 -0.00011043629937781788 0.00016666068104817703 -1.2469872030544143e-5 0.0002301284373111574 5.0100251160564316e-5 -7.03711951195706e-5 -9.692834582389169e-5 9.786383880558595e-5 -7.077858326234687e-5 2.2206967246836764e-5 -7.530203883512668e-5 -0.0001462962421050744 7.594786139228601e-5 5.3519940325293556e-5 8.578465777239957e-5 
1.630329707544466e-5 -2.4481615850583842e-5 -0.0002109795780547412 -1.2569692712042588e-5 -3.874552139174022e-5 -2.1772467602624774e-5 -0.00010410657277540131 -5.7369026121216104e-5 6.240027569723486e-5 0.00014700893071270717 0.0002190566271615699 -3.3053898227645233e-6; -3.969749281285554e-5 0.0001462947812422033 2.963568757654459e-6 -7.843123004248483e-6 -6.0172885188193055e-5 -2.1701202914524248e-5 5.9507337081188216e-5 1.5125849184865912e-5 2.3013978587871762e-5 4.374240287435885e-6 -0.00017681822913916606 0.00013044858956960046 0.00011109826033020951 -0.00016561516071419361 -5.396548562316206e-5 -9.038906404500117e-5 -0.00012075709213717842 0.00013565229624757792 -1.9961403214664705e-5 -3.0314705273904712e-5 0.0001409347433389053 0.000199579063572631 6.723679421952757e-5 4.566027660871366e-5 0.0001780115827011756 -0.000195405914329281 0.00015329930198163338 0.000145708309954675 -0.00010238285443960935 0.00011029723560455112 1.7727098619382376e-5 3.757877953716892e-5; 0.000295803475780203 -2.2367122192596983e-6 -0.00010434650798111527 2.066006086983038e-5 3.5961850841830364e-5 -4.006244384350624e-5 0.00028962243337167797 1.9857311742209432e-5 4.485599598132495e-5 0.0001611939234470835 -7.859759975244075e-5 -0.00021777770493195702 1.5524606459961046e-6 -7.741344492665983e-5 0.0001848462135575961 -0.00010668869694480739 3.3363289873652614e-5 -0.00016602511354275672 6.9765986297114e-5 0.00014438299800251464 0.00021777273028141263 -1.9234504036884776e-5 0.00020253156358825442 -0.0001297073311826027 -0.00013766599189972225 0.00015955275844766443 -1.1172462876004675e-5 5.2605218258524876e-5 -8.93123691918023e-6 -5.303326801646068e-5 6.569066187628816e-6 -2.6551321610045853e-5; 0.00016891432278446088 -0.00012631219760860085 0.00012700575280123337 -4.771106796451737e-5 -5.699591383799809e-5 0.0001510327547288404 3.494793963855192e-5 0.0002609976199886704 -0.00016565009900714232 -2.4668628612364417e-5 2.2429851014217647e-5 -0.00019525145944553478 -1.6877365852085295e-5 
5.183964640255111e-5 -0.0002758427017402933 7.325159640405445e-5 0.00011188495955084404 -5.037505798153218e-5 -8.001190326521976e-5 -4.405166146224614e-5 0.00016264260739208257 -0.00019419059572155564 -0.0001907595597009263 3.0830380637238066e-5 -4.397678458244443e-5 2.9436474505383032e-5 -0.00010972720519984913 -1.1690146494037961e-5 8.985167914888907e-5 -2.3313610654985232e-5 -4.12179124401458e-5 4.8906005758366436e-5; 0.00012636681176775825 -6.776926338381799e-5 0.0001014264079491888 -0.000141958761647692 6.428887982445409e-5 -9.67862843902508e-5 0.0001533581261389211 -4.845286557276758e-5 4.594347674690916e-5 -7.691717937300868e-5 0.00011018876366838352 0.0001967628221825045 1.6646005773218247e-5 9.146362484610225e-6 -4.233040553691791e-5 -1.100615381469053e-5 0.0002790662864787782 3.492201622398199e-5 5.562164269911635e-5 -6.270676121524399e-5 7.130597081124857e-6 7.416108829027813e-5 -0.00015301804259826648 4.440612511068895e-5 -6.913631395611329e-5 0.0001234786349447723 -0.0001197720604034842 -5.677823091565638e-5 7.143638768291713e-5 0.00014121567337832224 2.1656730074668336e-6 5.723352196917119e-5; -6.102074644259606e-5 5.999253312508032e-5 -2.1026953844959684e-5 -7.374402070967294e-5 2.0115461439454822e-5 0.00024003745761616124 -3.642125034206439e-5 3.6175417108793073e-5 -0.00021529114042026856 1.3571969930840554e-5 -0.0001907672051395073 9.777912760069523e-5 -8.225215446986002e-5 6.080340200022857e-6 2.135436238928155e-5 0.00011242951563943027 -0.00019132793408902423 2.8655331341556343e-6 -2.978138604634222e-5 6.300760264891385e-5 0.00010912585253277892 3.243149308287576e-5 -0.00017452897031875248 -8.053045917153005e-5 -3.1561903824003594e-5 1.7206182520346958e-5 -3.234967175489024e-5 -6.4693497272331565e-6 -0.00014404378475705017 -0.00014373572071166649 8.270770172272725e-5 7.27044896400578e-5; -1.596448186196829e-5 -4.136546289256281e-5 -7.237149060412634e-5 3.940230140986993e-5 -4.584525357780453e-6 -3.080658956488835e-5 6.887733667216263e-5 
-0.0003209198252431684 9.761687448306873e-5 -0.00019528587967094555 -6.834249455885133e-5 -8.949356510012727e-5 7.395118933462804e-5 -5.7372391382810674e-5 -2.9976875738355756e-5 0.00011523734479971261 -0.0001671349522791553 0.00012244161240353035 -6.668209920345606e-5 5.6452296631782384e-5 -8.138695506087518e-5 0.00013004843691422484 -4.762543488933748e-5 -0.00021243663599189867 -1.5644108715345003e-5 -8.527401900093379e-5 7.988061457698518e-5 -3.060162583889747e-5 -5.9565226764529195e-5 1.971369834908203e-5 -7.273976319875195e-5 3.0228235629048546e-5], bias = [-3.3089855218709614e-9; 5.179324200080406e-10; 8.018792987520062e-10; -2.9651984221084073e-9; -3.093897815634631e-9; -1.4365692018549015e-9; 3.200247865469278e-9; -1.0632979212011584e-9; 9.80860978170254e-10; 1.1561200683129724e-10; 6.420256499533289e-10; -1.954103966366882e-9; 6.092193462149565e-10; 5.100552449691167e-10; -4.482184919368228e-10; -1.651450007094175e-9; 2.0627815571479632e-9; -1.9746427508144914e-9; 1.0179818529850977e-10; -2.9535726074899974e-9; -3.773233804636438e-9; -4.6537238501398336e-9; 3.107367476455209e-10; -1.0836653384599934e-9; -3.258630358139372e-9; -4.6652041588962746e-11; 2.209216067754353e-9; 2.7077084656257076e-9; -1.1659727530408207e-9; 2.7529528644745285e-9; -1.899251621989272e-9; -3.32460214706526e-9;;]), layer_4 = (weight = [-0.0006641003506076221 -0.0007573967362427818 -0.0006590435702003828 -0.0005901029880453178 -0.0007029853139558628 -0.0008255407723157481 -0.0008296568689844642 -0.0007831283104266153 -0.0007780841493544111 -0.0007390133230125263 -0.0005145244383111488 -0.0006485936724745542 -0.0006037786912226114 -0.0007031916089796124 -0.0007439295660608738 -0.0007392565501084468 -0.0007527950370493198 -0.0006653278968779308 -0.0007517645115685069 -0.0006130170608561811 -0.0005983860639117017 -0.0006199161482674354 -0.0008798232548152674 -0.0006802261471907393 -0.00056347676779303 -0.0006870472153069372 -0.0008156450280495338 -0.0006915526227759168 
-0.0008045725600300341 -0.0005945680078347752 -0.0006091706021814512 -0.0006163740585877442; 0.00038792791332559846 6.396362777669558e-6 0.00022351897811310854 0.00018680314747431472 0.0002130637926622387 0.00029636881731241635 0.00036215442029342284 0.00038845128649731556 0.00011372675762204614 0.00011243238777229023 0.00036606676625966717 0.00024079240539004327 0.00028009722368038903 0.00011224708154705192 0.0002677900518771902 0.0002822137420674422 0.00015317470780658947 0.00023396848759270076 0.0001386015844422546 -4.634073252371714e-6 0.0002832885662284618 0.00021690774172851933 0.0003469999908788563 0.000165416577745433 0.00037293739536530625 0.00018648280620482783 0.0002667437259808785 0.00018175488337693227 0.0001664908525303658 0.0003518432553376741 0.00022137096294746596 0.00035313774859930635], bias = [-0.0006652423891173525; 0.00022080905712146433;;]))

Visualizing the Results

Let us now plot the loss over the training iterations.

julia
begin
    # Plot the recorded training loss against the iteration index.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    # Makie spells these attributes `marker` and `alpha`; the Plots.jl names
    # `markershape`/`markeralpha` are not Makie scatter attributes (ignored by older
    # Makie, rejected by newer versions), so the intended translucency never applied.
    scatter!(ax, 1:length(losses), losses; marker=:circle,
        markersize=12, alpha=0.25, strokewidth=2)

    fig
end

Finally, let us visualize the results.

julia
# Re-integrate the neural ODE with the optimized parameters `res.u`, using fixed-step
# RK4 with step `dt` and output at `tsteps`, then convert the solution to a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4();
    u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params),
)

begin
    # Overlay the data, the untrained network's waveform and the trained network's
    # waveform on a shared time axis; each series is drawn as a line plus markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Marker styling uses Makie attribute names: `marker` (Plots.jl's `markershape`)
    # and `alpha`; Makie has no `markeralpha` attribute, so it is not passed.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    # Each legend entry combines the line and the scatter of the same series.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Append GPU toolchain details only for backends that are loaded and report themselves
# as functional; each section is preceded by a blank line.
if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 48
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.662 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.