Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; from a scientific machine learning perspective, however, they are not particularly relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) `path` into the two individual body paths.

The first returned path is `path` scaled by `m₂ / (m₁ + m₂)`, the second is `path`
scaled by `-m₁ / (m₁ + m₂)`, so that their difference reproduces `path`.

# Returns
A tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix into Cartesian orbit coordinates.

Each column of `soln` is one saved time point. Two row layouts are accepted:

  - 2 rows `(χ, ϕ)`: the constants `p, M, e` are taken from `model_params`, which must
    then have length 3 (`M` is unpacked but not used here).
  - 4 rows `(χ, ϕ, p, e)`: `p` and `e` are read per time point from `soln` and
    `model_params` is ignored.

# Returns
A `2 × size(soln, 2)` matrix whose rows are `x = r cos(ϕ)` and `y = r sin(ϕ)` with
`r = p / (1 + e cos(χ))`.

# Throws
`AssertionError` if `soln` has neither 2 nor 4 rows, or if the 2-row layout is used
without a length-3 `model_params`.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first so a missing `model_params` reports the intended
        # message instead of a MethodError from `length(nothing)`.
        @assert model_params !== nothing && length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function computes the first time derivative by finite differences: central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled values `v` with spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`; the first and last
points use three-point one-sided stencils. `v` needs at least three entries.

# Returns
A vector with the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    ahead = v[3:end]
    behind = v[1:(end - 2)]

    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (ahead .- behind) / 2
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]

    return vcat(left_edge, interior, right_edge) ./ dt
end
d_dt (generic function with 1 method)

This function computes the second time derivative by finite differences: central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled values `v` with spacing `dt`.

Interior points use the central stencil `v[i-1] - 2v[i] + v[i+1]`; the first and last
points use four-point one-sided stencils. `v` needs at least four entries.

# Returns
A vector with the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    behind = v[1:(end - 2)]
    centre = v[2:(end - 1)]
    ahead = v[3:end]

    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = behind .- 2 * centre .+ ahead
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]

    return vcat(left_edge, interior, right_edge) ./ (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Compute one component of the moment tensor time series for `orbit`.

`orbit` holds `x` in its first row and `y` in its second row. `component` is indexed as
`component[1]`, `component[2]`:

  - `(1, 1)` gives `x² - (x² + y²) / 3`,
  - `(2, 2)` gives `y² - (x² + y²) / 3`,
  - any other value gives the off-diagonal product `x y`.

# Returns
A vector with one entry per column of `orbit`, scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    row, col = component[1], component[2]

    third_of_trace = (xs .^ 2 .+ ys .^ 2) ./ 3

    moment = if row == 1 && col == 1
        xs .^ 2 .- third_of_trace
    elseif row == 2 && col == 2
        ys .^ 2 .- third_of_trace
    else
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the requested moment-tensor `component`.

The component time series comes from `orbit2tensor(orbit, component, mass)` and is
differentiated with `d2_dt2` using sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment_series = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment_series, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate `h_22_quadrupole_components` for the `(1, 1)`, `(1, 2)` and `(2, 2)` components.

# Returns
The tuple `(h11, h12, h22)`, in that order.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    # Local shorthand so each component differs only in its index pair.
    component(i, j) = h_22_quadrupole_components(dt, orbit, (i, j), mass)
    return component(1, 1), component(1, 2), component(2, 2)
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the two strain time series for a single orbiting body.

The quadrupole components `(h11, h12, h22)` are obtained from `h_22_quadrupole` with the
default unit mass and combined as `h₊ = h11 - h22` and `hₓ = 2 h12`. Both are scaled by
`sqrt(π / 5)`, evaluated in the number type `T` of `dt`.

# Returns
The tuple `(sqrt(π/5) * h₊, -sqrt(π/5) * hₓ)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The normalisation is the square root of π / 5, not π / 5 itself.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Add the quadrupole components of two bodies.

`h_22_quadrupole` is evaluated once per body (with that body's orbit and mass) and the
results are summed component by component.

# Returns
The tuple `(h11, h12, h22)` of summed components.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the two strain time series from the orbits of two bodies with masses `mass1`
and `mass2`.

The summed quadrupole components from `h_22_quadrupole_two_body` are combined as
`h₊ = h11 - h22` and `hₓ = 2 h12`, then scaled by `sqrt(π / 5)` evaluated in the number
type `T` of `dt`.

# Returns
The tuple `(sqrt(π/5) * h₊, -sqrt(π/5) * hₓ)`.

# Throws
`AssertionError` if `mass1 + mass2` differs from 1 by `1e-12` or more.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The normalisation is the square root of π / 5, matching the one-body strain.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into strain time series.

`soln` is converted to an orbit with `soln2orbit(soln, model_params)`. For
`mass_ratio == 0` the one-body strain is returned. For `0 < mass_ratio ≤ 1` the masses
are set to `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, the orbit is split
with `one2two`, and the two-body strain is returned.

# Returns
The 2-tuple produced by `h_22_strain_one_body` or `h_22_strain_two_body`.

# Throws
`AssertionError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        # Pass the locals computed above; the two-body branch previously referenced
        # names that are not defined in this function.
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

RelativisticOrbitModel defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi$, $u[2] = \phi$

where, p, M, and e are constants

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE in out-of-place form.

`u` holds `(χ, ϕ)`; only `χ` enters the right-hand side. The second argument is
destructured into the constants `p`, `M` and `e`. The time `t` is not used.

# Returns
The two-element vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    e_cosχ = e * cos(χ)

    # Factor shared by both rates.
    shared = (p - 2 - 2 * e_cosχ) * (1 + e_cosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared * sqrt(p - 6 - 2 * e_cosχ) / (M * (p^2) * root)
    dϕ = shared / (M * (p^(3 / 2)) * root)

    return [dχ, dϕ]
end

# Settings shared by the true-model simulation and the neural ODE below.
mass_ratio = 0.0         # test particle; 0 selects the one-body branch of compute_waveform
u0 = Float64[π, 0.0]     # initial conditions (χ, ϕ)
datasize = 250           # number of saved time points
tspan = (0.0f0, 6.0f4)   # timespan for GW waveform (note: Float32 literals)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing between saved samples, used for finite differences
dt = 100.0               # fixed step handed to the ODE solver (adaptive stepping is off)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic ("true") model with fixed-step RK4 (adaptive stepping
# disabled), saving the state at every entry of `tsteps`, and convert the result to an array.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a 2-tuple; only its first element is kept as the target data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform against time as a line with overlaid markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    # NOTE(review): `markershape` / `markeralpha` may not be recognised by every Makie
    # version — confirm against the installed CairoMakie.
    s = scatter!(ax, tsteps, waveform; markershape=:circle,
        markersize=12, markeralpha=0.25, alpha=0.5)

    # Group line and markers under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    # Last expression of the block, so the figure is the block's value.
    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives.

It is typically never recommended to use globals, but in case you do use them, make sure to mark them as const.

We will deviate from the standard Neural Network initialization and use WeightInitializers.jl.

julia
# Network: an element-wise cosine of the input, two Dense layers with cosine activations
# (1 => 32 and 32 => 32), and a final 32 => 2 Dense layer without an activation argument.
# All Dense weights are drawn from a truncated normal with std = 1e-4.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# Create the initial parameters `ps` and layer states `st`.
# NOTE(review): `Xoshiro()` is given no seed, so presumably the initial weights differ
# between runs — confirm whether reproducibility is wanted.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[2.1932952f-5; -0.000117923104; 3.041442f-5; 8.263941f-5; 0.00011905383; 0.00013345796; -4.5304343f-5; 1.7967443f-5; -1.6202259f-5; 3.860245f-5; -3.539703f-5; 3.9421644f-5; 5.6996367f-5; 5.8406094f-5; -0.00015446196; 4.794833f-6; -4.3468535f-5; 6.8200745f-5; -2.2279162f-5; -6.548639f-5; -4.8189664f-5; -7.4472584f-5; 3.255409f-5; -2.510247f-5; -4.6227473f-5; -0.00010157688; 1.3083046f-5; 5.8234178f-5; -3.3233628f-5; 3.9905914f-5; 0.00010956935; 0.00011959672;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[6.828196f-5 -9.608253f-5 -7.950564f-5 -8.7850945f-5 2.610508f-5 2.5082932f-6 -0.00016358828 -2.2570002f-5 -0.00014590906 0.00021655987 -0.00025545343 4.9488433f-5 2.313303f-5 -7.9431375f-5 3.8339393f-5 -8.5082385f-5 2.1679734f-5 8.661607f-5 -0.00020622245 -5.3402127f-5 -0.00012100204 0.00011002533 0.00021827524 -0.00010021072 -5.5492597f-5 -0.00018331503 -3.516101f-5 -0.00016330834 3.1100346f-5 0.00016779531 0.00011680629 1.6873306f-5; 4.6731682f-5 -2.8946335f-5 -3.55476f-6 0.00012036927 -0.00013784107 1.3775291f-5 -0.00011274148 5.912086f-5 0.0001371909 -5.793972f-5 2.7784028f-5 -4.3907545f-5 -5.314491f-6 0.000115422605 -0.00022955779 -4.071548f-5 0.00015543833 -0.00014864425 0.00012508548 -9.7616954f-5 0.00010023173 -2.3095557f-5 4.2568183f-5 -0.0002361659 -9.952895f-5 -0.00012768201 -0.00014128743 -0.00010441406 0.00010731669 4.9250466f-5 0.00010060114 -6.548325f-5; 1.7026769f-5 -0.00017071069 0.00012540152 2.69383f-5 -7.450849f-5 -8.691052f-5 3.4964192f-5 0.0001359212 -1.9878673f-6 3.7226175f-5 -2.0523608f-5 -0.00022597448 -1.09955145f-5 7.786487f-5 6.756013f-5 8.818072f-5 -0.00011478735 6.573446f-5 8.625998f-6 5.0294202f-5 0.0001889294 0.00031393673 -0.00010227276 3.5122757f-5 0.00016373939 -4.045959f-5 6.93952f-5 9.886659f-5 -9.344535f-5 
0.000108031025 -4.684478f-5 0.00014776888; -0.0001622775 -4.957525f-5 -4.946854f-5 0.00010855914 4.3166194f-5 -7.419449f-5 -3.3888005f-5 0.00018963816 -4.8831327f-5 0.00014172557 8.034f-6 -6.760405f-6 -7.7915174f-5 0.0001603931 2.9404846f-5 0.00020462014 -4.3848177f-6 8.752038f-5 6.48759f-5 0.00011066481 5.081923f-5 9.957381f-5 9.712484f-7 -1.2885922f-5 -0.00023813518 -5.0657377f-6 -8.322994f-5 -5.4639833f-5 -0.00013327147 -5.4771776f-6 -0.000119533506 0.00012835847; -0.000113486 -0.0001284756 9.785238f-5 -6.379138f-5 0.00013829763 0.00018702839 -0.000107292304 -3.417045f-5 2.0823685f-5 3.6357758f-6 7.8091194f-5 -1.4760102f-5 0.0002130492 -0.00012254057 -6.8994595f-5 -0.00012366596 -0.00016616622 7.35367f-5 5.722438f-5 5.8952224f-5 -7.5200514f-5 0.00011517394 -6.8590816f-5 -4.4703585f-5 5.492295f-5 -2.2807897f-5 5.7556997f-5 -0.00012561081 5.4922115f-5 -8.651424f-5 0.00013286762 8.463793f-6; -6.0870036f-5 -3.8243557f-5 0.00023301398 -0.00021188612 1.3708941f-6 -8.714907f-5 -0.000102528655 -0.00016341762 -1.7895991f-6 -6.588599f-5 -3.5170437f-5 -3.803842f-5 -8.9261026f-5 -8.311943f-5 1.8919261f-5 9.952f-5 3.7403126f-5 -0.00015134714 0.0001234986 -0.00013733946 1.2446126f-5 0.00013490298 -4.270844f-5 0.00012657173 -8.161576f-5 -5.296665f-5 -3.9507082f-5 -9.780132f-6 -0.0002290694 3.673577f-5 -1.8567256f-5 8.6863925f-5; 0.00011639609 -4.1419058f-5 9.044503f-6 -6.0544586f-5 3.6284433f-5 -5.445015f-5 7.792655f-5 -5.8581463f-5 7.2016715f-5 3.2293083f-5 7.6365795f-5 -1.630177f-5 5.1065428f-5 -0.00010756 5.869788f-5 -0.00011277801 0.00012100285 -0.0002156505 0.00024342834 -3.938738f-5 -5.0671286f-5 8.973939f-5 -0.00011741752 5.562002f-5 -6.2715844f-5 -2.6015154f-5 1.855128f-5 -2.8971767f-6 -2.3975655f-5 -3.5153826f-5 -9.450883f-5 -0.00021314109; 4.6354962f-5 -7.833309f-5 0.00014709162 1.5411173f-5 8.63922f-5 -9.2478615f-5 0.00015349522 -8.422096f-5 -1.220617f-5 -4.608615f-7 -1.9041294f-5 -0.00020777438 -9.959633f-5 -9.658586f-5 -9.161674f-6 8.5756865f-6 0.00011714851 
-9.848362f-5 9.932941f-5 9.7052645f-5 -8.260944f-5 -0.00014029315 0.00013421358 -8.7660745f-5 -2.5954376f-5 -0.00016917159 4.404065f-5 8.114559f-5 9.286826f-6 -0.00011592625 7.997241f-5 -6.693903f-5; -2.3302266f-5 0.00024232849 -0.00016017906 1.863206f-5 0.00012868083 -0.0001635048 4.3018776f-5 7.223213f-5 -0.00013465143 -0.00016405585 -6.92948f-5 1.9187087f-6 3.0029438f-5 1.2923209f-5 -0.00019480298 -7.361329f-5 -4.69209f-6 -0.00013723939 5.6953766f-5 0.00010809216 -4.5144574f-5 1.4204202f-5 -6.425016f-5 -0.00020656528 4.504947f-5 7.120651f-5 -8.619299f-5 0.0001332502 0.00013078503 -3.7498354f-5 -0.00015409775 6.12861f-5; 6.145436f-5 1.1261216f-5 -0.00010160274 3.545045f-5 -1.2634765f-5 -0.00010012833 4.319645f-5 0.00012081684 -1.7679617f-5 9.112708f-5 -1.245394f-5 -6.5528606f-5 8.984419f-5 2.7819962f-5 0.00013968893 9.847454f-5 -5.1784955f-5 -2.223882f-5 3.242368f-5 -1.6714943f-5 9.616378f-5 3.322218f-5 0.0002289037 1.726666f-5 7.0978254f-5 -8.649407f-6 -9.195196f-5 6.294235f-6 0.00013656485 -0.0002038439 3.3275571f-6 5.6359488f-5; 0.00011459213 0.00031042736 2.0532885f-5 6.671746f-5 -7.7535115f-5 8.681938f-5 -1.961993f-5 0.00018600252 2.6344067f-5 -2.1999225f-5 -0.00014214973 0.00020035885 2.8019162f-5 -8.151045f-5 -9.37591f-5 3.5998786f-5 0.00012518358 -4.4849836f-5 0.0001255852 -1.2074894f-5 0.00016039895 -5.4072403f-5 0.00022259088 4.52232f-5 -1.3419176f-5 5.3990716f-5 -3.4376397f-5 0.00015718241 0.00023110856 -2.5055198f-5 -3.2160926f-5 -0.00016630067; 4.312792f-5 -0.00011966823 3.2034513f-5 -7.0429014f-5 -6.554f-5 2.0911806f-5 1.4019724f-5 -6.0765965f-6 0.00012534454 -0.00016444872 -0.00014474629 2.8301225f-5 4.3634875f-5 0.0001058362 -9.3984825f-5 -9.008694f-6 -1.043619f-5 5.4238557f-5 3.605541f-5 9.28021f-5 -1.3163109f-6 -0.00015391121 -2.9210765f-5 0.00013773108 4.1279767f-5 -0.00011830872 1.9989096f-5 0.0001342742 4.620382f-5 -5.6314534f-6 0.00010609754 8.5475855f-5; -0.00017338913 0.00012024456 -4.6926394f-5 0.00019391577 -0.00014944645 1.45274f-5 
-0.00013094975 0.00017552085 0.00016055397 -0.00012811311 0.00024523033 -5.7205692f-5 2.3733133f-5 5.6779947f-5 -0.00010001996 0.00017877141 6.612384f-5 3.3104265f-5 0.00020289379 -1.7039487f-5 6.945076f-7 3.1454165f-5 -5.5851688f-5 -0.00012213916 6.44512f-5 4.9665654f-5 7.719678f-5 7.320934f-5 -0.00018208244 -6.598534f-5 -9.665296f-5 0.000101702004; -6.336984f-5 8.554671f-5 -1.1173066f-5 -4.4387645f-5 -9.255448f-5 3.911505f-5 -0.00010030981 9.550889f-5 -7.671562f-5 -3.1705484f-5 -7.1663155f-5 -5.6050558f-5 0.00011474734 -9.2332644f-5 2.0386711f-5 -6.169128f-5 -1.9502224f-5 6.963883f-5 0.0001288985 -7.0188635f-6 4.1043244f-5 2.682793f-5 -3.456896f-5 3.9346036f-5 6.0173807f-5 -2.8084338f-5 8.4538294f-7 -3.9182945f-5 3.9024246f-5 0.00015775315 0.0001348144 7.2542876f-5; -2.6217916f-7 2.5593698f-5 -9.480516f-5 -7.521419f-5 -0.0001782697 -3.0315272f-5 -0.00017293275 -0.000140589 8.447632f-5 -0.00016678932 0.00014096429 9.6390344f-5 2.9951205f-5 -1.0377515f-5 5.5014945f-5 8.698971f-5 7.196172f-5 7.9131256f-5 -0.0001077228 -6.605536f-5 -8.6903674f-5 0.00021038605 8.688894f-5 -9.707861f-5 -2.2595674f-5 8.8311965f-5 -8.0678496f-5 9.589328f-5 -7.202394f-5 -3.619973f-6 -2.4431753f-5 9.828673f-5; -1.2021137f-5 -3.545697f-5 -3.5323184f-5 0.00013149914 -6.7836205f-5 2.8888353f-5 4.901f-5 0.00010639507 -0.00010416486 1.0609608f-5 5.0448725f-5 6.553859f-5 -8.1628685f-5 -1.4709021f-6 -5.228029f-6 7.322403f-5 -0.0001810145 6.324807f-5 -5.069121f-5 -6.6067456f-5 -1.3653491f-5 5.658326f-5 0.00012872898 0.00015667772 5.89791f-5 0.00020832811 7.895306f-5 0.00012273862 -0.000121694364 0.00018726967 -1.2963918f-5 0.00014185155; -2.997188f-5 0.00020059927 -0.00013000611 -9.4435505f-5 0.000121905134 -6.352742f-5 -7.48915f-5 2.4450426f-5 -0.00017075664 0.000154132 -0.00016410327 3.6130103f-5 0.0001070148 -4.513584f-6 0.00017129711 -5.339241f-5 0.00013056735 4.1312276f-5 2.9811225f-5 -0.00012457996 -6.218325f-5 -7.413861f-5 8.877652f-5 -0.00010787036 0.00013092699 3.7183203f-5 -2.5204634f-5 
2.704599f-5 -4.6781947f-6 2.9392959f-5 -2.0126723f-5 -0.000104057886; -2.561298f-5 -0.00016316079 -0.000294215 0.00019019574 -2.3361083f-5 9.4340714f-5 0.00011490884 -0.000102808815 -3.3316403f-5 -2.0515216f-5 -9.499234f-5 -0.00018637074 -1.9079316f-5 -1.5996793f-5 -1.8345072f-5 -0.00024026776 -4.1454467f-5 -2.0442469f-5 -0.00012392623 -5.739355f-5 0.00010054955 0.000109530556 -9.369417f-6 -0.000105236504 3.0411586f-5 9.8287965f-5 1.34957045f-5 -4.650851f-5 -0.00022104719 5.1923937f-5 -0.00018413478 -5.690814f-5; -7.718606f-5 0.00015123085 9.824044f-5 -5.473278f-5 4.7632086f-5 0.00014327593 -7.379786f-5 -7.030422f-5 2.151426f-5 2.7788003f-5 0.00012239518 7.90101f-6 -0.000111765876 0.00015277413 -4.3250995f-5 -7.6757715f-5 0.00016542184 -1.3240323f-5 -3.8564274f-5 -7.3238814f-5 -6.2729196f-5 4.329657f-5 -5.2791096f-5 4.8339247f-5 3.2936056f-5 -3.126361f-5 0.0001489057 3.891509f-5 4.509945f-5 1.7030794f-6 -2.4579002f-5 6.835251f-6; 0.00021830534 1.4352694f-5 -2.3061448f-5 -0.00015720174 -0.00012114978 -2.3791224f-6 4.305542f-6 0.00018206362 5.443622f-5 0.00022683572 1.2985645f-5 -0.00010322951 0.00023371253 7.059614f-5 -4.4406156f-6 5.977707f-5 -4.2168333f-5 -1.2991627f-5 4.105969f-5 -5.2308435f-5 0.0001663021 -6.9624155f-5 0.00022551957 8.3074374f-5 3.0417845f-5 0.0001005158 -7.382778f-5 4.5255045f-5 -7.70412f-6 -5.2901734f-5 -3.5049765f-5 -7.322094f-5; -4.1813164f-5 -2.4872425f-5 -9.669556f-5 0.0001328519 -0.0001252628 0.00011794713 0.00016109922 -6.598772f-5 2.7632905f-5 7.1550516f-5 3.207835f-5 1.7853652f-5 5.4949237f-6 -0.00011105748 3.5319306f-5 -6.882455f-6 4.0229057f-5 1.5286407f-5 -0.00010705372 -7.68507f-5 -0.00011666244 -0.00021942816 -0.00011747907 9.9350305f-5 1.925278f-5 -5.6080735f-5 0.0001578313 -0.00018232467 -8.1931146f-5 -0.0002420535 1.5892974f-5 6.270408f-5; -1.4347909f-5 2.0914846f-5 0.00025928128 7.589989f-5 9.984268f-5 9.478196f-5 0.00018286315 7.749583f-5 -5.133063f-5 -7.737602f-5 -8.261936f-5 8.256726f-6 -0.00011479217 -7.801347f-5 
-7.1074064f-5 -1.806771f-5 -0.00020045189 0.00022000859 -6.0813923f-5 0.00023100528 0.00013977269 7.6283824f-5 4.434742f-5 -1.2612973f-5 0.00013634372 -9.617877f-6 0.00014837008 1.0485193f-5 1.2278786f-5 -6.1550265f-5 -1.0102498f-5 0.00011543456; 4.4660166f-5 0.00014619084 -1.4833029f-5 -0.0001761829 -0.000109364504 -5.716336f-5 0.00022187052 -1.4649532f-5 -9.325449f-6 1.6507307f-5 9.8007324f-5 3.847661f-5 -5.1644744f-5 -5.8862264f-5 5.6527475f-5 0.00013051764 0.000108823355 0.00012927673 -7.391161f-5 4.0174564f-5 -9.251246f-5 0.00017323987 2.8532686f-5 -5.3090207f-5 3.5836747f-5 -5.1817693f-5 -4.5092933f-5 -6.778689f-6 -1.1796787f-5 1.2929582f-5 -0.00012898588 -2.7887028f-5; 9.0729256f-5 -2.39186f-5 8.523667f-6 -0.00015170484 -4.7555684f-5 -0.00015814522 -0.00010385778 7.0165064f-5 4.0100593f-5 -3.3261967f-6 -8.010452f-5 7.018739f-5 -5.669374f-6 -3.3802255f-5 2.4624153f-5 9.167515f-5 -1.5210256f-5 3.1066993f-5 -1.2137068f-5 -6.270515f-5 2.3934952f-8 3.6846064f-5 2.6719854f-5 0.00011416571 5.7084126f-5 -7.243903f-5 -0.00010300538 -1.5034044f-5 -4.1767773f-5 -0.00018828355 0.0002082225 -3.5861744f-5; 3.44734f-5 0.000119209486 -0.00012793558 0.00014822905 -3.9255876f-5 -8.1761704f-5 1.8210072f-5 0.00016625768 4.1925787f-6 -0.00013225539 4.0951883f-5 0.00023907526 1.9887671f-5 3.7209604f-5 -5.5916695f-5 -4.1615298f-5 1.3155305f-5 -4.8151025f-5 0.00014874179 0.00011068664 -1.4592374f-5 0.00013158527 -8.836511f-5 0.00010392407 -2.2558232f-5 3.1297503f-5 -9.2601156f-5 3.8866456f-5 0.00012158747 0.0001500725 -3.523335f-6 -0.00012515238; -0.00012616324 -3.8045578f-6 -0.0001494162 4.400954f-5 -0.00013595575 -8.879549f-5 9.534983f-5 -0.00014649333 3.351858f-5 0.00015571532 -8.557814f-5 2.04901f-5 0.0002231342 -1.3990876f-5 -2.6293243f-5 -3.423277f-5 1.911603f-5 0.0001737224 -6.597633f-5 -0.00013573434 0.00010640102 0.00010756912 -0.00016666445 -0.0001395698 7.994227f-5 -0.00017168203 0.00019140015 7.1789626f-5 -9.38633f-6 1.585067f-5 -0.00015612705 -0.00018903187; 
-2.1527167f-5 7.5533644f-6 -0.00020396155 -2.8633554f-5 8.736518f-5 -4.319689f-6 0.00011594017 -7.1400646f-5 2.848166f-6 -0.00020379774 7.8185614f-5 -0.00010194495 -6.817283f-5 0.00020331996 0.00014700282 -0.00010668315 -5.2586915f-7 -0.00012379134 9.567527f-5 0.00016515386 -0.00014292847 6.373327f-5 5.675032f-5 4.507613f-5 1.6977292f-5 -5.50102f-5 2.1845832f-5 2.9877114f-5 4.4903127f-6 0.00010105247 8.133764f-5 0.0001358152; -4.7549947f-5 -2.9428373f-5 6.854137f-5 7.917315f-5 4.8967362f-5 -6.903896f-5 1.7604198f-5 0.00013011182 -2.9435347f-5 0.00020266204 -3.5512934f-5 -2.9802852f-5 4.5568246f-5 7.391216f-5 3.9044015f-5 4.2615837f-5 5.3416705f-5 6.400837f-5 5.990198f-7 -9.334875f-5 2.6622769f-5 -3.8342496f-5 0.00011926965 8.206087f-5 -9.662203f-5 5.34636f-6 -6.8870184f-5 0.00013022947 6.126778f-5 6.82697f-5 -5.8824247f-5 0.00017915586; -0.00010142708 6.694781f-5 -0.00011558499 2.9768351f-5 9.473617f-5 3.0927684f-5 -7.415298f-5 5.840999f-5 4.456887f-6 3.3583066f-5 7.025574f-5 3.925173f-5 -0.00010866518 5.187756f-5 -0.00022949702 -1.8143344f-5 -1.5195293f-5 4.52401f-5 -9.3030205f-5 -2.0733505f-5 8.187751f-6 -4.7109515f-5 8.203901f-5 -1.9607198f-5 0.0001350466 -1.5382826f-5 4.6003897f-6 0.00012260016 9.774803f-6 -5.3675462f-6 -2.5980213f-5 0.0002267034; 9.549184f-5 -0.000102763384 0.00015179372 -2.627906f-5 0.00022466616 7.882539f-5 -9.2929724f-5 0.000100689125 0.00018307717 -2.1256168f-5 -7.7702796f-5 3.6802157f-5 8.480449f-5 -7.675122f-5 -9.613099f-5 6.593261f-5 -8.659869f-7 -5.858577f-5 6.1878345f-5 -1.1180192f-5 -2.7231986f-6 5.6637644f-5 -3.6389498f-5 -4.1498704f-5 -8.454687f-5 6.3402535f-5 5.6711096f-6 -7.201971f-5 -7.661604f-5 -1.6508957f-5 -5.627597f-5 0.00013630028; -5.7276728f-5 8.329872f-5 1.3173435f-5 6.767334f-5 0.00015845025 7.2797244f-5 0.00012201065 -6.529519f-5 6.503992f-5 -0.00020285822 -7.178901f-5 -6.973986f-5 3.8061065f-5 -1.2569915f-6 -3.255981f-5 -4.303023f-5 7.1661f-6 4.2873482f-5 -9.660895f-5 7.970911f-5 -0.00021904417 0.00010434085 
-5.1149873f-5 8.8711175f-5 7.500384f-5 0.00015284495 9.506755f-5 -0.00010889166 -3.7006477f-5 -4.312974f-5 -0.00013272912 0.00017300504; -6.8880145f-5 -9.4610106f-5 8.75518f-5 -3.8996324f-5 0.00013291248 -3.0255114f-5 -1.9323072f-5 -2.2391489f-5 -5.7288744f-6 8.06083f-5 -4.621986f-5 -0.00016749212 9.399203f-5 -5.403094f-5 -0.00010862879 -0.00012860331 -6.111138f-5 -4.7800077f-5 -2.2286238f-5 -0.00012288599 1.8583354f-5 -1.8027695f-5 5.8970658f-5 -2.3725666f-5 -1.6811402f-5 -0.00023557413 -3.13767f-5 0.00017709202 -5.5990236f-5 0.00020122915 0.00010195658 7.5417074f-6], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[5.6396424f-5 3.462562f-5 0.00011002084 0.0001603618 6.782931f-5 -1.5254406f-5 5.0612347f-5 -0.00017135851 0.000114374096 1.8646437f-6 -0.00013764211 -6.4417436f-5 -0.0001612543 0.000103592894 0.0002803526 -0.000102594735 1.2642657f-5 7.359255f-5 -0.00019345696 -4.8070695f-5 0.00012133798 0.00028141082 -4.276755f-5 0.00012366612 9.3563074f-5 -0.0001365668 -1.0187378f-5 4.956689f-5 0.00016261893 -6.072639f-5 -4.6417535f-5 -5.8932936f-5; 2.1339814f-5 5.4340704f-5 0.00013476252 0.0001411146 -0.00014633793 -0.00026336292 5.4888154f-5 -9.803621f-5 -2.508312f-5 4.0102565f-5 7.9628524f-5 -0.000119004435 2.6812071f-5 8.936386f-5 8.519369f-5 -5.0069033f-5 -0.00013370572 -3.6167585f-5 0.00017016871 1.9929581f-5 -0.00010413908 9.577582f-5 0.00018869652 0.00010097966 -3.4662567f-6 8.4540436f-5 -0.00023287973 -9.659372f-5 5.390453f-6 -5.4925782f-5 6.1543986f-5 0.00013852258], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32, however, in this case we need Float64

julia
# Store the parameters as a Float64 ComponentArray; this is what the optimiser updates.
const params = ComponentArray{Float64}(ps)

# Bundle the network with its state `st` so it can be called as `nn_model(x, parameters)`.
const nn_model = StatefulLuxLayer(nn, st)
Lux.StatefulLuxLayer{true, Lux.Chain{@NamedTuple{layer_1::Lux.WrappedFunction{Base.Fix1{typeof(broadcast), typeof(cos)}}, layer_2::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_4::Lux.Dense{true, typeof(identity), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}, layer_4::@NamedTuple{}}}(Chain(), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()), nothing)

Now we define a system of ODEs which describes the motion of a point-like particle with Newtonian physics, using

$u[1] = \chi$, $u[2] = \phi$

where, p, M, and e are constants

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE in out-of-place form.

`u` holds `(χ, ϕ)`. Both rates share the factor `(1 + e cos(χ))^2 / (M p^(3/2))` built
from the global `ode_model_params`; each is multiplied by `1 +` the matching output of
the global `nn_model`, evaluated on `[χ]` with parameters `nn_params`. The time `t` is
not used.

# Returns
The two-element vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # The network state is held by `nn_model` itself; in this example it is an empty
    # NamedTuple, so nothing needs to be threaded through here. In general the state
    # of the neural network has to be tracked.
    correction = 1 .+ nn_model([first(u)], nn_params)

    base_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [base_rate * correction[1], base_rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Same fixed-step RK4 setup as for the true model, but with the neural ODE right-hand
# side and the (untrained) network parameters as the problem parameters.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the tuple returned by `compute_waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Plot the reference waveform and the untrained network's waveform on the same axis.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform; markershape=:circle, markersize=12,
        markeralpha=0.25, alpha=0.5, strokewidth=2)

    # Untrained neural ODE prediction.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn; markershape=:circle,
        markersize=12, markeralpha=0.25, alpha=0.5, strokewidth=2)

    # One legend entry per waveform (line and markers grouped), placed at `:lb`.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    # Last expression of the block, so the figure is the block's value.
    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
"""
    loss(θ)

Sum of squared differences between the reference `waveform` and the waveform predicted
by the neural ODE with parameters `θ`.

The ODE is re-solved with fixed-step RK4 using the global `prob_nn`, `u0`, `tsteps` and
`dt`, and converted to a waveform with `compute_waveform`.

# Returns
The tuple `(loss_value, pred_waveform)`.
"""
function loss(θ)
    trajectory = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    strains = compute_waveform(dt_data, trajectory, mass_ratio, ode_model_params)
    pred_waveform = first(strains)
    residual = waveform .- pred_waveform
    return sum(abs2, residual), pred_waveform
end
loss (generic function with 1 method)

Warmup the loss function

julia
# Evaluate the loss once with the untrained parameters as a warm-up call before training.
loss(params)
(0.1877630582570687, [-0.02428000664578264, -0.023494117791161197, -0.022708228936539458, -0.02138029758073852, -0.019482489462838433, -0.016974397034565325, -0.013802958083269111, -0.009899479156454682, -0.005181901334402667, 0.00044554730515302966, 0.007083089865121185, 0.014814268659425435, 0.023657833299152647, 0.03343794751759524, 0.043502357838898914, 0.05205564918673644, 0.05474932339684436, 0.04239595244471215, 0.0014067148617146017, -0.06710140642542807, -0.11041419383272781, -0.0755670242753345, -0.006259786139610124, 0.03892139491996867, 0.0541413271861917, 0.05277772744847532, 0.044713126140068546, 0.034756529603462025, 0.02492365815435679, 0.015968134628604642, 0.008105352955508224, 0.0013368915918147866, -0.004415029341498842, -0.009248822675718804, -0.013260492669043791, -0.0165324015404745, -0.019134658979046933, -0.0211219931790968, -0.022535865369100076, -0.0234053073546254, -0.023747848985893893, -0.023569829261340644, -0.02286761588074634, -0.021625262525580165, -0.019816524392333664, -0.017402491669913894, -0.01433073759766844, -0.010534676444438656, -0.005932680961193051, -0.00042953435630538724, 0.006076070769178174, 0.013675374030105604, 0.022400029798746685, 0.03211500062262028, 0.04225585570207433, 0.05122935499874561, 0.0551116235744414, 0.04540692096623239, 0.00858798994110583, -0.05824704136212099, -0.109013307406764, -0.08346677616419446, -0.014371228519418522, 0.03496188332442659, 0.053265059410532264, 0.053382260948599596, 0.0458823071697414, 0.03606724957876013, 0.026196594390400285, 0.0171358326479986, 0.009142429073312681, 0.002245721614243588, -0.00363374517322194, -0.008582021280950756, -0.012703933560259879, -0.016075150643666065, -0.01877389439651607, -0.020849742449877653, -0.02235081799549641, -0.023303320064159968, -0.023728594771270563, -0.023633107133858088, -0.023013440125020195, -0.02185670312519264, -0.02013762739643874, -0.01781544056377544, -0.014844884434690937, -0.011153479076060327, -0.006668572479148362, 
-0.0012873624139438072, 0.005084712294505299, 0.012551577597039714, 0.021151388455065406, 0.030790839406514798, 0.04097863864479155, 0.05031058716124899, 0.05525210866293535, 0.04798056523963482, 0.01524878892106066, -0.04916914245698347, -0.10624285929265836, -0.09064073816588175, -0.022853373918058063, 0.03049524208524632, 0.052094745363947516, 0.053855516268311844, 0.04700261081495643, 0.03736631045859967, 0.027477422404426723, 0.018313385158837424, 0.01019847124632429, 0.0031667589858153634, -0.0028328117752229975, -0.007903130235648757, -0.012129720286108871, -0.01560640451115671, -0.018397093537847423, -0.020565922603014777, -0.022151360004977687, -0.023189157853661654, -0.02369626806715435, -0.0236829233442666, -0.023147227753749843, -0.022074744046471678, -0.020444328296728126, -0.01821513850641808, -0.015343387925038024, -0.011757905995606794, -0.007387625078818326, -0.002129584715992911, 0.004110598603773451, 0.011441950375652161, 0.019914911919537558, 0.029465736372939578, 0.03967753815550673, 0.04931038637970392, 0.05518918134103081, 0.05014781877827933, 0.02136345186138754, -0.04001214176170138, -0.10218308176587342, -0.09691852267121288, -0.031625798918604134, 0.025504008886597977, 0.050607242107297716, 0.05418104157707443, 0.04806681738712525, 0.038652551787729666, 0.02875807139375474, 0.019508060801052916, 0.01126577489004725, 0.004107564447606359, -0.002022153615201633, -0.007203181125942962, -0.01154388342028053, -0.015120882227088359, -0.018010290692246758, -0.020265410238035495, -0.02194030199430119, -0.023061448989500036, -0.023651335312771315, -0.023720141628671598, -0.023267254298671575, -0.022280439695683083, -0.020735351108956194, -0.018602978944137098, -0.01582634638945256, -0.012346091924462404, -0.008091583940206348, -0.002956729805675072, 0.0031542390863138862, 0.010349208552470467, 0.018688567662826977, 0.0281431918047952, 0.038356622209443615, 0.048237991112088735, 0.054946444768130286, 0.0519313198559797, 0.0269332133465024, 
-0.030929351110500734, -0.09692694045131774, -0.10215047274715827, -0.040591163494706366, 0.019984714427610822, 0.04876955635844052, 0.054345552661398885, 0.049066470060914594, 0.03991802917980888, 0.030047243934656385, 0.02070888422914521, 0.012349235369624486, 0.005060273310352238, -0.0011875381152913682, -0.006493588443423209, -0.010942273009222203, -0.0146234555401307, -0.017604197529070654, -0.01995520787275621, -0.021715483065075595, -0.022921169743129664, -0.023592892864403518, -0.02374462241912502, -0.023374355325205536, -0.022471005505517328, -0.021017325815785345, -0.01897319856494374, -0.016295607831758513, -0.012920008012266055, -0.008778586770337115, -0.003766182613057161, 0.0022137414931603133, 0.009270879943819818, 0.017476908613387323, 0.026825156991463824, 0.037019240211727825, 0.047103667874686526, 0.05453927996831793, 0.05336236856532287, 0.031953302007735704, -0.022032786438792067, -0.09061116806136257, -0.10620813632568263, -0.04962069871372888, 0.013926201849331911, 0.04656219976908287, 0.05432803775856564, 0.04999255430200912, 0.041163073323371074, 0.03133288750918147, 0.021922076952791917, 0.01344661097197738, 0.006030836663375145, -0.00034177916293331523, -0.005766572012755247, -0.010325690246610347, -0.014110220134663646, -0.017187538322894495, -0.019629944535957147, -0.021477546438315478, -0.022767700207406334, -0.023522042195503225, -0.023755983766017147, -0.023468440132061057, -0.02264959957566193, -0.021283157744813614, -0.01933070987251701, -0.01675019781945687, -0.0134787167941833, -0.009449724555885984, -0.005420732317588492])

Now let us define a callback function to store the loss over time

julia
# Running record of the loss reported at every optimizer iteration.
const losses = Float64[]

# Optimization callback: append the current loss `l` to `losses`, print a progress line
# with the iteration count, and return `false` (Optimization.jl treats a `true` return as
# a request to stop). `θ` and `pred_waveform` are accepted but not used.
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf("Training %10s Iteration: %5d %10s Loss: %.10f\n", "", iteration, "", l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Use Zygote as the AD backend; the objective's second argument is not used by `loss`.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, capped at 1000 iterations. `callback` is handed
# to the solver so the loss is recorded and printed as training proceeds.
res = Optimization.solve(optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [2.1932952222397913e-5; -0.00011792310397113319; 3.041441959789127e-5; 8.26394098111617e-5; 0.00011905383144041886; 0.0001334579574176357; -4.530434307523554e-5; 1.7967442545259487e-5; -1.6202258848345968e-5; 3.8602451240873166e-5; -3.5397031751896495e-5; 3.942164403264791e-5; 5.6996366765807925e-5; 5.840609446741921e-5; -0.0001544619590276672; 4.794832875620669e-6; -4.346853529573149e-5; 6.82007448630201e-5; -2.2279162294546533e-5; -6.54863906673842e-5; -4.818966408484318e-5; -7.4472583946717e-5; 3.255408955731187e-5; -2.5102470317448572e-5; -4.622747292154756e-5; -0.00010157687938751646; 1.308304581470695e-5; 5.8234178140860826e-5; -3.3233627618691425e-5; 3.9905913581590896e-5; 0.00010956935147991862; 0.00011959671974161207;;], bias = [1.8731274805406694e-17; -3.0873212829197617e-16; 5.960257736521858e-17; 8.433758593898219e-17; -2.6255342084821742e-17; 7.764142822338363e-17; -2.4711137832452903e-17; 9.210922607281473e-17; -7.399323597284532e-18; 4.14034616802338e-18; -3.4753467180562365e-17; 5.944284224699867e-17; 1.0470377445968028e-17; 9.221113423317791e-17; -1.9956648192452702e-17; 1.2651525988365428e-17; 4.8415842958618705e-17; 1.2643497286748033e-16; -3.822137233404277e-17; -1.3019497341656912e-16; -9.609278515067045e-17; -5.900054201368654e-17; 7.672881000901775e-17; -6.466318669791662e-17; -5.880365585463965e-17; -2.325193087205363e-16; 3.369608518691112e-18; 1.4830873685203058e-16; -7.342475732283832e-17; 1.7471911097496727e-17; -4.587623670373591e-17; 2.3323251477259974e-16;;]), layer_3 = (weight = [6.827872560375441e-5 -9.608576080511707e-5 -7.950887381316045e-5 -8.785417891789419e-5 2.6101847407581466e-5 2.5050596660320064e-6 -0.00016359151347291418 -2.2573235827572093e-5 -0.0001459122897985666 0.0002165566372523754 -0.0002554566629159528 4.948519993473029e-5 2.3129796983137622e-5 -7.943460811378983e-5 3.833615913307025e-5 -8.508561883804365e-5 2.1676500847258012e-5 
8.66128363950881e-5 -0.00020622568560515456 -5.34053608761867e-5 -0.0001210052753494374 0.00011002209496933639 0.00021827200246757962 -0.00010021394998052072 -5.549583081023522e-5 -0.0001833182643794624 -3.516424409912859e-5 -0.00016331157827966474 3.109711245183971e-5 0.00016779207792847606 0.00011680305463511846 1.687007236893879e-5; 4.672949350033754e-5 -2.894852381021825e-5 -3.55694877388988e-6 0.00012036707763737389 -0.0001378432557760654 1.377310260406165e-5 -0.00011274366855677236 5.9118670345655137e-5 0.00013718871329156533 -5.7941908868356294e-5 2.7781839402182183e-5 -4.3909733482264545e-5 -5.3166795953927176e-6 0.00011542041598996225 -0.00022955997568376696 -4.071766901351676e-5 0.00015543614559983805 -0.00014864643757054417 0.00012508329515522213 -9.76191427401253e-5 0.00010022954071583226 -2.309774621270835e-5 4.256599410264661e-5 -0.00023616808770047625 -9.95311407098548e-5 -0.0001276841999572645 -0.00014128961496343993 -0.00010441624775584713 0.0001073145044426425 4.924827723125873e-5 0.00010059894835986885 -6.548543618581338e-5; 1.7030590336143716e-5 -0.0001707068639894157 0.00012540534417210348 2.6942121714883734e-5 -7.45046690076734e-5 -8.690669883212109e-5 3.4968014099356375e-5 0.0001359250255540972 -1.98404554568707e-6 3.722999654426368e-5 -2.0519786537553466e-5 -0.00022597065790025398 -1.0991692668107302e-5 7.786869000708575e-5 6.756395496490559e-5 8.818454271247921e-5 -0.00011478352634869435 6.573827895633611e-5 8.62981969701611e-6 5.029802387296004e-5 0.00018893322894880578 0.000313940554499343 -0.00010226893745995876 3.512657904364404e-5 0.0001637432088523177 -4.045576670026442e-5 6.939902423480817e-5 9.887041424233358e-5 -9.344152740370306e-5 0.00010803484665528054 -4.684095792757592e-5 0.00014777270016789836; -0.000162276314479376 -4.957406298009524e-5 -4.946735378572065e-5 0.00010856033015196637 4.316738097012186e-5 -7.419330239164639e-5 -3.3886818239543084e-5 0.00018963934536922467 -4.8830139969785435e-5 0.00014172675702619998 
8.035186799398508e-6 -6.759217906355158e-6 -7.791398700893293e-5 0.0001603942885751985 2.9406033097330983e-5 0.0002046213279002508 -4.383630478250224e-6 8.75215640165056e-5 6.487708862304424e-5 0.00011066599956146966 5.08204187256437e-5 9.957499543569528e-5 9.72435548002033e-7 -1.288473475388751e-5 -0.0002381339933733302 -5.06455051682947e-6 -8.322875390395772e-5 -5.463864607255348e-5 -0.0001332702874448924 -5.475990458498257e-6 -0.0001195323190429132 0.00012835965873572185; -0.00011348606292520807 -0.00012847566338385584 9.785231870702312e-5 -6.379144704331789e-5 0.00013829756992273838 0.00018702832514756238 -0.00010729236762634633 -3.417051427894429e-5 2.0823621728791216e-5 3.6357123219340235e-6 7.809113069307833e-5 -1.47601656651978e-5 0.00021304914318826828 -0.00012254063565413618 -6.899465895620287e-5 -0.00012366602257023724 -0.00016616628797854453 7.353663588511608e-5 5.72243172639127e-5 5.895216074642576e-5 -7.520057786385683e-5 0.00011517387813293227 -6.859087968840421e-5 -4.470364890656591e-5 5.4922888214800836e-5 -2.280796050107243e-5 5.755693402826628e-5 -0.0001256108714885932 5.492205147967518e-5 -8.651430632815852e-5 0.0001328675518591037 8.463729703376182e-6; -6.087366902876333e-5 -3.824719071644336e-5 0.00023301034987297274 -0.0002118897490822791 1.3672607360645885e-6 -8.7152703529219e-5 -0.00010253228804004813 -0.00016342124847370837 -1.793232443117936e-6 -6.588962689648881e-5 -3.517406996480718e-5 -3.8042052367483545e-5 -8.9264659366117e-5 -8.312305992377265e-5 1.8915628131723743e-5 9.951636918911969e-5 3.739949268264242e-5 -0.00015135076948582325 0.00012349496417071983 -0.00013734309588699775 1.2442492404706462e-5 0.00013489934379094256 -4.2712074830444454e-5 0.00012656809583630745 -8.161939597131176e-5 -5.297028376162449e-5 -3.951071532584451e-5 -9.783765492565456e-6 -0.0002290730289337373 3.673213821229036e-5 -1.8570889480996094e-5 8.686029202863643e-5; 0.00011639489419834014 -4.142025537945481e-5 9.04330553096995e-6 -6.0545782994418005e-5 
3.62832357357894e-5 -5.4451346309324714e-5 7.792534923722354e-5 -5.85826605085165e-5 7.201551814872089e-5 3.2291885495244715e-5 7.636459812131997e-5 -1.6302967574272098e-5 5.106423085165577e-5 -0.00010756119846989028 5.869668128511173e-5 -0.00011277920968056417 0.0001210016522098457 -0.00021565170428872621 0.00024342714058078153 -3.938857882091368e-5 -5.067248313495558e-5 8.973819384194435e-5 -0.00011741871677729021 5.5618823885044865e-5 -6.271704152383521e-5 -2.6016350980320128e-5 1.8550082663598135e-5 -2.89837400047552e-6 -2.39768527673436e-5 -3.5155022865335234e-5 -9.451002953497917e-5 -0.00021314228471525152; 4.635340673881938e-5 -7.833464615362591e-5 0.00014709006217923393 1.5409617568698963e-5 8.639064771738952e-5 -9.24801700454082e-5 0.00015349366157930776 -8.422251840205515e-5 -1.220772565251338e-5 -4.6241699795894595e-7 -1.904284936575756e-5 -0.00020777594027425686 -9.959788285744813e-5 -9.658741901479159e-5 -9.163229427599308e-6 8.57413096706482e-6 0.00011714695728766227 -9.848517793546796e-5 9.932785332817939e-5 9.705108977538152e-5 -8.261099565391912e-5 -0.00014029470758362988 0.0001342120246557557 -8.766230012373637e-5 -2.5955931326858964e-5 -0.000169173143425389 4.403909560675141e-5 8.114403471733013e-5 9.285270331349831e-6 -0.00011592780202012095 7.997085203566596e-5 -6.694058384130583e-5; -2.3304690759333856e-5 0.0002423260644067342 -0.0001601814824167478 1.8629634294842616e-5 0.00012867840044091317 -0.00016350722077877118 4.301635047194385e-5 7.22297022806637e-5 -0.0001346538578869619 -0.0001640582727046518 -6.929722664755063e-5 1.916283534582776e-6 3.0027012624842353e-5 1.2920784255030461e-5 -0.00019480540729236172 -7.361571849434024e-5 -4.694515129458756e-6 -0.00013724181414691693 5.695134089218167e-5 0.00010808973231941475 -4.5146999101789234e-5 1.4201776619659446e-5 -6.425258695490897e-5 -0.00020656770581955366 4.504704477621086e-5 7.120408329537622e-5 -8.619541265180643e-5 0.000133247774582188 0.0001307826073829247 -3.750077891855701e-5 
-0.00015410017883514367 6.128367161036334e-5; 6.145736926327636e-5 1.1264224774532762e-5 -0.00010159972942574749 3.54534569100805e-5 -1.2631756255831556e-5 -0.00010012532209886824 4.319945957603887e-5 0.00012081984834139427 -1.767660786072526e-5 9.113009229394685e-5 -1.2450931428707125e-5 -6.55255970582452e-5 8.984719544741232e-5 2.7822970985191708e-5 0.0001396919361627357 9.84775507436063e-5 -5.1781946133365465e-5 -2.2235812032257655e-5 3.242668764641684e-5 -1.6711934115399696e-5 9.616678664087314e-5 3.3225187614780456e-5 0.0002289067016295076 1.7269669222055694e-5 7.098126308619059e-5 -8.646398675878008e-6 -9.194895017799715e-5 6.297243651254185e-6 0.00013656785824195597 -0.00020384088919172928 3.3305658393904265e-6 5.636249635925273e-5; 0.00011459794004001878 0.00031043317684920974 2.053869829358206e-5 6.672327398806756e-5 -7.752930142201069e-5 8.682519509706366e-5 -1.9614117633705784e-5 0.00018600833536145528 2.634988012976278e-5 -2.199341216221421e-5 -0.00014214391595198693 0.00020036465829723532 2.8024974738556652e-5 -8.150463727982117e-5 -9.375328991779625e-5 3.600459930856091e-5 0.0001251893914679154 -4.484402271085058e-5 0.00012559100977630575 -1.2069080555038886e-5 0.00016040476438608823 -5.406659002513433e-5 0.00022259669603333924 4.5229014716075814e-5 -1.3413362836969076e-5 5.3996529089233196e-5 -3.4370583677466004e-5 0.0001571882235959232 0.00023111436857395912 -2.504938530075818e-5 -3.215511279163058e-5 -0.00016629485654161857; 4.3128672202618144e-5 -0.00011966747426657585 3.203526701083144e-5 -7.04282605717558e-5 -6.553924633652558e-5 2.0912559718739e-5 1.4020477990304941e-5 -6.075842905085455e-6 0.00012534529066880916 -0.00016444796658350428 -0.00014474553192729398 2.830197861596198e-5 4.363562818736449e-5 0.00010583695673387175 -9.398407148427222e-5 -9.00794059151093e-6 -1.0435436181689682e-5 5.42393105000129e-5 3.6056164915242515e-5 9.280285719552665e-5 -1.315557299387644e-6 -0.00015391045910535754 -2.9210011619108602e-5 0.00013773183725544928 
4.128052076095423e-5 -0.00011830796886232242 1.998985005895363e-5 0.00013427495703337392 4.62045723700776e-5 -5.630699818250022e-6 0.0001060982945794559 8.547660816231155e-5; -0.00017338683409311772 0.00012024684821093416 -4.6924102378990465e-5 0.00019391805818608522 -0.00014944416051426962 1.4529691553738848e-5 -0.00013094745389812703 0.00017552314221697907 0.00016055626457396016 -0.00012811081995883064 0.00024523262237503173 -5.7203400673326536e-5 2.373542503059671e-5 5.678223879962977e-5 -0.0001000176709464271 0.00017877370538494622 6.612612906915067e-5 3.3106556485422926e-5 0.00020289608056495934 -1.703719516504339e-5 6.967993602732945e-7 3.145645661001683e-5 -5.5849395893029934e-5 -0.0001221368658558742 6.445348827705631e-5 4.966794568767071e-5 7.719907426812745e-5 7.32116312168241e-5 -0.0001820801482505445 -6.598304770582404e-5 -9.665067062875152e-5 0.00010170429578587347; -6.33685380321295e-5 8.554800624415447e-5 -1.1171767620446445e-5 -4.4386346782171153e-5 -9.255318255008974e-5 3.911634812402543e-5 -0.00010030850935572092 9.551018920183925e-5 -7.671432504275069e-5 -3.1704185313622066e-5 -7.166185641971645e-5 -5.60492593663117e-5 0.00011474863604099175 -9.233134587838636e-5 2.038800957825251e-5 -6.168998188649481e-5 -1.9500925556251537e-5 6.964012900289706e-5 0.0001288997915239706 -7.017565077598047e-6 4.10445423754091e-5 2.6829227498188243e-5 -3.456766026055661e-5 3.934733431243134e-5 6.017510495303864e-5 -2.808303943447448e-5 8.466813355954085e-7 -3.918164693595229e-5 3.9025544173007463e-5 0.00015775444953624852 0.00013481569803707553 7.254417468948441e-5; -2.6299377646936625e-7 2.5592882881022053e-5 -9.48059748079783e-5 -7.521500782241579e-5 -0.0001782705218684048 -3.0316086562384755e-5 -0.0001729335633026588 -0.00014058982167990583 8.447550527852323e-5 -0.00016679013759495365 0.0001409634717820058 9.638952934982583e-5 2.995039023152459e-5 -1.0378329247041572e-5 5.501413062665593e-5 8.698889752485147e-5 7.196090183787384e-5 7.913044129599294e-5 
-0.00010772361536875065 -6.605617789683465e-5 -8.690448856578792e-5 0.000210385231472165 8.688812551189528e-5 -9.707942780006374e-5 -2.2596488346617032e-5 8.831115002612209e-5 -8.06793101983302e-5 9.589246502945575e-5 -7.202475506711436e-5 -3.6207877055953484e-6 -2.443256715560661e-5 9.828591311451709e-5; -1.2017164940533667e-5 -3.545299894998419e-5 -3.531921227933662e-5 0.0001315031108015065 -6.783223269397998e-5 2.889232511275305e-5 4.901397134288246e-5 0.00010639904159232299 -0.00010416088533476662 1.0613579811415285e-5 5.045269737363929e-5 6.554256052821041e-5 -8.162471257039467e-5 -1.466929993533101e-6 -5.22405669930967e-6 7.322800176096445e-5 -0.00018101053191005198 6.325204541551575e-5 -5.0687238931476605e-5 -6.606348377774727e-5 -1.3649518755320485e-5 5.658723186113452e-5 0.00012873295001098118 0.00015668169309263335 5.898307373625248e-5 0.0002083320863812034 7.895703257871137e-5 0.00012274259589956001 -0.00012169039162929016 0.00018727364605636428 -1.2959946238911687e-5 0.00014185551791793977; -2.9971782800888938e-5 0.0002005993686045671 -0.0001300060154742406 -9.443540847746983e-5 0.00012190523039430031 -6.352732243135933e-5 -7.489140052235068e-5 2.4450522965022075e-5 -0.00017075654404357385 0.00015413209103393987 -0.00016410317557052087 3.6130199355507886e-5 0.00010701489582538288 -4.513487233229369e-6 0.00017129721009518566 -5.3392313652470044e-5 0.000130567447108884 4.131237283695168e-5 2.9811321205894168e-5 -0.00012457986822005682 -6.218315474575699e-5 -7.413851353226554e-5 8.877661667179007e-5 -0.00010787026048524613 0.00013092708314183784 3.7183299632711674e-5 -2.5204537173298476e-5 2.7046086238666055e-5 -4.6780980435452325e-6 2.9393055506485186e-5 -2.0126626345476796e-5 -0.0001040577896626515; -2.5618697509415634e-5 -0.00016316650526282611 -0.0002942207045635566 0.0001901900270240134 -2.336679954173695e-5 9.433499708389416e-5 0.00011490312517701081 -0.00010281453232899265 -3.332211945242214e-5 -2.0520932412252012e-5 -9.499805837556503e-5 
-0.00018637646080612859 -1.908503309102608e-5 -1.6002509983147845e-5 -1.8350788923299653e-5 -0.0002402734806310759 -4.1460183456434315e-5 -2.044818556903015e-5 -0.0001239319515662632 -5.739926727989841e-5 0.00010054383365052237 0.00010952483917201949 -9.37513395918118e-6 -0.00010524222107058292 3.0405868710590612e-5 9.828224774532053e-5 1.3489987577414316e-5 -4.651422732428966e-5 -0.00022105290354382053 5.191821965120081e-5 -0.0001841405008236098 -5.6913855405628447e-5; -7.71839789176444e-5 0.00015123293060425778 9.824251692592225e-5 -5.47307011383713e-5 4.763416445896115e-5 0.0001432780097293465 -7.379578373544866e-5 -7.030213809954834e-5 2.151633971337335e-5 2.779008145992666e-5 0.00012239726284136106 7.903088500853807e-6 -0.00011176379693822104 0.0001527762048699789 -4.324891624308768e-5 -7.675563601693947e-5 0.00016542391467922363 -1.323824408061298e-5 -3.856219545738681e-5 -7.323673553216442e-5 -6.272711679955048e-5 4.3298650509339996e-5 -5.278901726157354e-5 4.834132569342933e-5 3.29381343381384e-5 -3.1261532329131145e-5 0.00014890778193333299 3.891716705950801e-5 4.510152730475228e-5 1.7051582536000988e-6 -2.4576922713073418e-5 6.8373298763817435e-6; 0.0002183093740119602 1.4356724185409444e-5 -2.3057417657901783e-5 -0.0001571977052857248 -0.00012114575003703956 -2.3750922731551605e-6 4.309572139980176e-6 0.00018206765451273004 5.4440249349363776e-5 0.00022683975037457322 1.2989675432116979e-5 -0.00010322548116287323 0.00023371656350595546 7.060017303833389e-5 -4.4365854786372035e-6 5.978110055243691e-5 -4.216430245046655e-5 -1.2987596911767513e-5 4.106372141782227e-5 -5.230440439968836e-5 0.0001663061258829178 -6.962012486961371e-5 0.00022552360240174417 8.307840396687404e-5 3.0421874888495572e-5 0.00010051982751850645 -7.382375117354019e-5 4.525907495794469e-5 -7.700090301433127e-6 -5.289770417343937e-5 -3.504573471342762e-5 -7.32169072047054e-5; -4.181609748446987e-5 -2.4875358042095655e-5 -9.669849661112392e-5 0.00013284896575679997 
-0.0001252657324697567 0.00011794419568798609 0.0001610962850351027 -6.599065219451359e-5 2.7629971180076984e-5 7.154758221769198e-5 3.207541757329766e-5 1.7850718529745573e-5 5.491990202802771e-6 -0.00011106041475830926 3.531637283925642e-5 -6.885388677246586e-6 4.022612355366809e-5 1.5283473283138818e-5 -0.00010705665173408243 -7.685363508461286e-5 -0.00011666537594682035 -0.0002194310886010354 -0.00011748200032560702 9.934737189506525e-5 1.9249846980176114e-5 -5.608366828926514e-5 0.0001578283614322666 -0.00018232760468198315 -8.193407976467706e-5 -0.0002420564355019344 1.58900407656098e-5 6.270114553228723e-5; -1.4342985586017248e-5 2.0919769277613386e-5 0.00025928620522895445 7.59048103383424e-5 9.984760033251094e-5 9.478688077353604e-5 0.00018286807387189372 7.750075526123577e-5 -5.132570588907505e-5 -7.737109483864632e-5 -8.261443365777347e-5 8.261649392032808e-6 -0.00011478724848425768 -7.80085487611856e-5 -7.106914056647867e-5 -1.806278708395324e-5 -0.0002004469665088 0.00022001351096254725 -6.0808999868504146e-5 0.00023101020412572668 0.0001397776118857539 7.628874806972965e-5 4.435234365805578e-5 -1.2608048954325377e-5 0.00013634864223708752 -9.612953753154846e-6 0.00014837499981799828 1.049011695717617e-5 1.2283709630673348e-5 -6.154534133097647e-5 -1.009757436701842e-5 0.00011543948273458804; 4.4661434555559666e-5 0.0001461921079744467 -1.483176073907588e-5 -0.00017618163600946734 -0.00010936323578154 -5.716209251933214e-5 0.00022187178784133173 -1.4648263816530028e-5 -9.324180967273773e-6 1.6508575396978133e-5 9.800859207370593e-5 3.847787830363288e-5 -5.164347511982356e-5 -5.886099588042873e-5 5.652874334099038e-5 0.00013051890947007899 0.00010882462320932436 0.00012927799489897808 -7.391033922704563e-5 4.017583215226565e-5 -9.251119390046924e-5 0.00017324113524136228 2.853395436158685e-5 -5.308893868732899e-5 3.583801536205461e-5 -5.181642463231563e-5 -4.509166446314808e-5 -6.777420724580781e-6 -1.1795518178592438e-5 1.2930850613878836e-5 
-0.0001289846089939697 -2.788576002150111e-5; 9.072800481949905e-5 -2.3919850478739892e-5 8.522415795556638e-6 -0.00015170608797783738 -4.7556935656791856e-5 -0.0001581464746196035 -0.00010385903418504596 7.016381301942248e-5 4.009934151686692e-5 -3.3274479720088236e-6 -8.010576833480034e-5 7.018613565738328e-5 -5.670625101668145e-6 -3.380350584488271e-5 2.462290127729521e-5 9.167390113721555e-5 -1.5211506921371616e-5 3.1065741735763394e-5 -1.2138318867113247e-5 -6.270639981816396e-5 2.2683727156163245e-8 3.684481295200243e-5 2.671860262485683e-5 0.0001141644609233083 5.708287450393702e-5 -7.24402777778227e-5 -0.00010300663392267229 -1.5035295051386605e-5 -4.176902415205726e-5 -0.0001882848007299284 0.00020822124521506453 -3.586299475140336e-5; 3.4476938536715975e-5 0.00011921302467382773 -0.00012793203967590935 0.00014823259226543967 -3.9252337218185446e-5 -8.175816509609296e-5 1.8213610757116577e-5 0.00016626122197634495 4.196117343997651e-6 -0.00013225184847102681 4.095542209255356e-5 0.00023907880205227804 1.989121011432344e-5 3.721314243306005e-5 -5.591315601922519e-5 -4.161175928744499e-5 1.3158843840970725e-5 -4.8147486439818356e-5 0.00014874532899851846 0.00011069017892735512 -1.4588835427971056e-5 0.0001315888101191661 -8.836156956819017e-5 0.000103927605606475 -2.2554692976668227e-5 3.130104124577457e-5 -9.25976175392515e-5 3.886999439713147e-5 0.00012159101134897074 0.0001500760434384904 -3.519796281377076e-6 -0.0001251488403693472; -0.0001261653802270807 -3.806695606222169e-6 -0.0001494183366798598 4.4007401281805406e-5 -0.00013595788785318743 -8.87976262515095e-5 9.534769325686811e-5 -0.00014649546809092096 3.3516441742489785e-5 0.00015571317764401588 -8.558027783006293e-5 2.0487962881402756e-5 0.00022313205537788846 -1.3993013983466115e-5 -2.629538046897999e-5 -3.423490723658303e-5 1.911389282899624e-5 0.0001737202559684627 -6.597846966765737e-5 -0.00013573648046299466 0.00010639888334257738 0.0001075669801298759 -0.0001666665908328059 
-0.0001395719433116531 7.994013268903097e-5 -0.00017168416396311649 0.00019139800989937094 7.178748766595785e-5 -9.388467863077435e-6 1.5848532182231285e-5 -0.0001561291852180571 -0.00018903400503626912; -2.1525733193255466e-5 7.5547978282020015e-6 -0.00020396011672563734 -2.8632120977264478e-5 8.736661407747748e-5 -4.318255780523954e-6 0.00011594160610694509 -7.139921212218551e-5 2.8495995207885353e-6 -0.00020379630581591189 7.81870476727012e-5 -0.00010194351483950298 -6.817139359772189e-5 0.00020332138961531225 0.00014700425532064392 -0.00010668171302530752 -5.244357383971107e-7 -0.0001237899049941828 9.567670354737112e-5 0.00016515529574245177 -0.00014292704139951562 6.373470392428247e-5 5.675175276120309e-5 4.507756244275854e-5 1.697872545047263e-5 -5.500876790610455e-5 2.1847265140979253e-5 2.9878547191047184e-5 4.491746092400248e-6 0.00010105390543468388 8.133907254669425e-5 0.00013581662647775807; -4.754581213500552e-5 -2.942423752257266e-5 6.854550402341151e-5 7.917728608435834e-5 4.8971497295031505e-5 -6.903482180361358e-5 1.760833284197897e-5 0.00013011595247508782 -2.9431211527927617e-5 0.00020266617021253 -3.550879891417293e-5 -2.9798716509042942e-5 4.557238090047043e-5 7.391629579206338e-5 3.904814976110524e-5 4.2619971923621845e-5 5.342083996357008e-5 6.40125023942365e-5 6.031550178778915e-7 -9.334461837579272e-5 2.6626904132631812e-5 -3.833836062267368e-5 0.00011927378609971798 8.206500452207183e-5 -9.661789708342825e-5 5.35049521703854e-6 -6.886604869077139e-5 0.00013023360470970508 6.127191837141006e-5 6.827383431801222e-5 -5.882011218501133e-5 0.00017915999588284458; -0.00010142608303555752 6.694880934685505e-5 -0.00011558399060720124 2.9769349878395735e-5 9.473716849481714e-5 3.092868277474216e-5 -7.415198129580872e-5 5.8410989679605776e-5 4.457885863876517e-6 3.358406533391651e-5 7.025673616465624e-5 3.9252727531338725e-5 -0.00010866418384227524 5.1878558536035944e-5 -0.00022949601921698365 -1.8142345340829182e-5 -1.5194293753093087e-5 
4.524109894430597e-5 -9.302920594324822e-5 -2.0732506215946105e-5 8.188749545870231e-6 -4.710851649849371e-5 8.204001127763168e-5 -1.9606198891205746e-5 0.00013504760349190278 -1.5381827013127872e-5 4.601388666406386e-6 0.0001226011612221822 9.775802339054452e-6 -5.366547300009172e-6 -2.597921380345149e-5 0.00022670439861269688; 9.54935427892885e-5 -0.00010276168479837037 0.00015179541940083018 -2.627736041691254e-5 0.00022466786019716836 7.88270887590781e-5 -9.292802435598172e-5 0.00010069082479671834 0.00018307886482429953 -2.12544685615766e-5 -7.770109671187379e-5 3.680385695983848e-5 8.480618696407512e-5 -7.674951787126163e-5 -9.612929291112458e-5 6.593431014548332e-5 -8.642873700893135e-7 -5.85840710533874e-5 6.188004469945345e-5 -1.117849235529738e-5 -2.7214990274418616e-6 5.663934341495827e-5 -3.6387798286934985e-5 -4.149700482744481e-5 -8.454516803399501e-5 6.340423412834002e-5 5.672809197083464e-6 -7.201801356646143e-5 -7.661433785474643e-5 -1.6507257283276853e-5 -5.6274271015044435e-5 0.0001363019807052654; -5.727584076387239e-5 8.329960754915323e-5 1.3174321874862971e-5 6.767422663822223e-5 0.00015845113292877735 7.27981306131098e-5 0.00012201153648627516 -6.529430621460439e-5 6.504081019509226e-5 -0.00020285733709035834 -7.17881202131362e-5 -6.973897044238097e-5 3.806195226341043e-5 -1.2561046428574996e-6 -3.255892162135909e-5 -4.302934477029718e-5 7.166987042615113e-6 4.2874368854858735e-5 -9.660806327749406e-5 7.970999930828336e-5 -0.00021904328687971872 0.00010434173517703501 -5.114898604963353e-5 8.871206145111105e-5 7.500472483068294e-5 0.00015284583704621818 9.506843988626842e-5 -0.00010889077467302307 -3.700559037545144e-5 -4.3128854404608015e-5 -0.00013272823790478282 0.0001730059294363594; -6.888210598268723e-5 -9.461206761775788e-5 8.754984039691127e-5 -3.899828559375138e-5 0.00013291051757400625 -3.0257075626333224e-5 -1.932503298454529e-5 -2.239344982959983e-5 -5.730835686029919e-6 8.060633583803546e-5 -4.622181993332821e-5 
-0.00016749408357433153 9.399006696761583e-5 -5.4032902402160585e-5 -0.00010863074974044769 -0.00012860526883165562 -6.111334048333666e-5 -4.780203787382602e-5 -2.228819946473175e-5 -0.0001228879523056646 1.8581392875001175e-5 -1.8029656051712497e-5 5.896869654528656e-5 -2.372762762539722e-5 -1.68133633211821e-5 -0.0002355760920851539 -3.1378662673581505e-5 0.00017709005944774493 -5.599219773436103e-5 0.00020122719026979883 0.0001019546212104112 7.539746117205458e-6], bias = [-3.2334982694737597e-9; -2.1887320110478945e-9; 3.821789942090659e-9; 1.1871791606706534e-9; -6.346161635067834e-11; -3.6333608601794743e-9; -1.1972726205613817e-9; -1.5555120971501767e-9; -2.4251418034366624e-9; 3.00871533565388e-9; 5.8131201317333175e-9; 7.535866810189886e-10; 2.29175507163716e-9; 1.2983932794165792e-9; -8.146198624028557e-10; 3.972140786421597e-9; 9.668405791660603e-11; -5.716901868721516e-9; 2.0788338885936508e-9; 4.030133505388905e-9; -2.933495479607799e-9; 4.923614279785056e-9; 1.2683876729488507e-9; -1.2512246430888038e-9; 3.5386722068270603e-9; -2.1378422896158828e-9; 1.4334160478440102e-9; 4.135204570877709e-9; 9.989368032466207e-10; 1.699549418852619e-9; 8.868387120344241e-10; -1.9613013716759266e-9;;]), layer_4 = (weight = [-0.0006569373074955615 -0.0006787082163202139 -0.0006033128163225712 -0.0005529721009892157 -0.0006455046197029135 -0.000728588087884833 -0.0006627215530932073 -0.0008846923867626247 -0.0005989597240450843 -0.0007114691107234074 -0.0008509753622239107 -0.0007777513521949832 -0.0008745881196639682 -0.0006097410013028875 -0.00043298131500017664 -0.0008159283533488587 -0.0007006912701724659 -0.0006397407734054777 -0.0009067907968539345 -0.0007614043076636632 -0.0005919957914009068 -0.00043192268658954705 -0.0007561014467230612 -0.0005896677789546419 -0.0006197706221216691 -0.0008499006347683962 -0.0007235212671568845 -0.0006637667203960848 -0.0005507149757635576 -0.0007740602613975246 -0.0007597514472151321 -0.0007722667870124322; 
0.00024758290927935606 0.0002805838321651734 0.00036100559299696774 0.00036735774165267397 7.990522629334173e-5 -3.711984000663982e-5 0.0002811313025640042 0.00012820693221436504 0.00020116000284387824 0.00026634566668057427 0.00030587146656530103 0.00010723871826106231 0.0002530551944708692 0.00031560700390740015 0.0003114368459320113 0.00017617402621792784 9.253743828100516e-5 0.00019007538116945909 0.0003964118427370727 0.00024617263840410645 0.00012210402879648242 0.0003220188390930302 0.0004149396649832605 0.00032722280905655667 0.00022277682715621244 0.00031078356334332606 -6.636582674895165e-6 0.00012964933518045782 0.00023163360420478882 0.00017131735705901898 0.0002877871376335945 0.0003647657096716287], bias = [-0.0007133339270035463; 0.00022624315666259275;;]))

Visualizing the Results

Let us now plot the loss over time

julia
# Plot the recorded training loss against the iteration index.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    # Makie selects the glyph with `marker`; `markershape`/`markeralpha` are Plots.jl
    # attribute names that Makie's `scatter!` does not define.
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    # Last expression of the block, so the figure is the displayed value.
    fig
end

Finally let us visualize the results

julia
# Re-solve the ODE with the optimized parameters `res.u`, using fixed-step RK4 and
# saving the solution at `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(
    prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))

# Keep only the first value returned by `compute_waveform`.
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay three waveforms on one axis: the data, the network's prediction before
# training, and its prediction after training.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Makie selects the glyph with `marker`; `markershape`/`markeralpha` are Plots.jl
    # attribute names that Makie's `scatter!` does not define. Transparency is set
    # through `alpha`.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    # Each legend entry groups a line with its matching scatter plot.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    # Last expression of the block, so the figure is the displayed value.
    fig
end

Appendix

julia
using InteractiveUtils
# Print Julia/platform details for reproducibility.
InteractiveUtils.versioninfo()

# Report GPU backend details only when the backend package is loaded and functional.
if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-2/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 48
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.443 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.