Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

# Make scalar (element-by-element) indexing of CUDA arrays an error instead of a slow
# fallback, so accidental non-vectorised GPU code is caught early.
CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model, which are not especially relevant from a scientific machine learning perspective.

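We need a very crude 2-body path. Treat the 1-body motion as the Newtonian relative position vector $r = r_1 - r_2$ and use Newtonian formulas to recover $r_1$ and $r_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, 4.3). With $M = m_1 + m_2$ and the centre of mass at the origin, $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$.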

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of two bodies with
masses `m₁` and `m₂`.

Body 1 is placed at `m₂ / (m₁ + m₂)` times `path` and body 2 at `-m₁ / (m₁ + m₂)`
times `path`, so that `m₁ r₁ + m₂ r₂ = 0` and `r₁ - r₂ = path`.
Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$ and $y = r\sin\phi$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an orbit solution in (χ, ϕ[, p, e]) form to Cartesian coordinates.
#
# `soln` has one column per time sample and either
#   * 2 rows (χ, ϕ): `model_params` must then be the 3-element collection (p, M, e)
#     supplying constant `p` and `e` (M is not needed here), or
#   * 4 rows (χ, ϕ, p, e): `p` and `e` are read per sample and `model_params` is ignored.
#
# Uses r = p / (1 + e cos χ), x = r cos ϕ, y = r sin ϕ and returns a
# 2 × size(soln, 2) matrix with rows x and y. Invalid layouts raise an AssertionError.
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first so a missing argument fails the assertion rather
        # than raising a MethodError from `length(nothing)`.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

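This function approximates the first derivative of a uniformly sampled series. It uses second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0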

julia
"""
    d_dt(v, dt)

Approximate the first derivative of the uniformly sampled series `v` with spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`. The first and
last points use second-order one-sided stencils. Requires `length(v) ≥ 3`. Returns a
vector of the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    forward_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    central = (v[3:end] .- v[1:(end - 2)]) / 2
    backward_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(forward_edge, central, backward_edge) / dt
end
d_dt (generic function with 1 method)

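This function approximates the second derivative of a uniformly sampled series. It uses second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0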

julia
"""
    d2_dt2(v, dt)

Approximate the second derivative of the uniformly sampled series `v` with spacing `dt`.

Interior points use the central stencil `v[i-1] - 2 v[i] + v[i+1]`. The first and last
points use second-order one-sided four-point stencils. Requires `length(v) ≥ 4`.
Returns a vector of the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    forward_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    central = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    backward_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(forward_edge, central, backward_edge) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one entry of the trace-free mass moment tensor along `orbit`.

`orbit` is a matrix whose first two rows are the `x` and `y` coordinates. `component`
is an index pair:
- `(1, 1)` gives `mass * (x² - (x² + y²) / 3)`;
- `(2, 2)` gives `mass * (y² - (x² + y²) / 3)`;
- any other pair gives the off-diagonal `mass * x * y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    xx = xs .^ 2
    yy = ys .^ 2
    third_trace = (xx .+ yy) ./ 3

    i, j = component[1], component[2]
    tensor = if i == 1 && j == 1
        xx .- third_trace
    elseif i == 2 && j == 2
        yy .- third_trace
    else
        xs .* ys
    end

    return mass .* tensor
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the trace-free
moment-tensor entry `component` along `orbit` for a body of the given `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(h11, h12, h22)` quadrupole components for one body of `mass` on `orbit`,
sampled with spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quadrupole(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    h11 = quadrupole((1, 1))
    h12 = quadrupole((1, 2))
    h22 = quadrupole((2, 2))
    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt, orbit)

Return the (2,2)-mode strain `(h₊, hₓ)` for a single unit-mass body on `orbit`, sampled
with spacing `dt`.

Components:
- the plus polarization is `h11 - h22`;
- the cross polarization is `2 h12`, returned negated;
- both are scaled by the prefactor `sqrt(π / 5)`.

`T = typeof(dt)` fixes the precision of the constants.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The prefactor is the square root of π/5, not π/5 itself.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the `(h11, h12, h22)` quadrupole components of a two-body system. Each component
is the sum of the corresponding one-body components of `orbit1` (mass `mass1`) and
`orbit2` (mass `mass2`), sampled with spacing `dt`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # Sum each (11, 12, 22) component over the two bodies.
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the (2,2)-mode strain `(h₊, hₓ)` for two bodies on `orbit1` and `orbit2`, sampled
with spacing `dt`.

The masses `mass1` and `mass2` must sum to 1; otherwise an AssertionError is raised.
Components use the summed quadrupole of both bodies:
- the plus polarization is `h11 - h22`;
- the cross polarization is `2 h12`, returned negated;
- both are scaled by the prefactor `sqrt(π / 5)`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The prefactor is the square root of π/5, matching the one-body strain.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Convert an orbit solution into the (2,2)-mode strain `(h₊, hₓ)` sampled with spacing
`dt`.

`soln` and `model_params` are forwarded to `soln2orbit`. `mass_ratio` (q) must lie in
`[0, 1]`; an AssertionError is raised otherwise.
- When `q == 0`, the test-particle (one-body) strain is used.
- Otherwise the orbit is split into two bodies with masses `m₁ = q / (1 + q)` and
  `m₂ = 1 / (1 + q)`, which sum to one.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

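`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses $u[1] = \chi$ and $u[2] = \phi$:

$$\dot\chi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot\phi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}}$$

where $p$, $M$, and $e$ are constants.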

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = [χ, ϕ]`.

`p`, `M`, and `e` are the constant orbit parameters and `t` is unused. Returns the
vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ = first(u)
    cosχ = cos(χ)

    shared = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * root)
    dϕ = shared / (M * (p^(3 / 2)) * root)

    return [dχ, dϕ]
end

# Setup for the "true" relativistic orbit that generates the training waveform.
mass_ratio = 0.0         # test particle (compute_waveform uses the one-body strain for 0)
u0 = Float64[π, 0.0]     # initial conditions: χ(0) = π, ϕ(0) = 0
datasize = 250           # number of saved samples
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform (Float32 endpoints)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved sample
dt_data = tsteps[2] - tsteps[1]  # sample spacing fed to the finite-difference derivatives
dt = 100.0               # fixed integrator step for the RK4 solves (run with adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic orbit with fixed-step RK4, saving at `tsteps`, and keep
# only the first element of the strain tuple (h₊) as the target waveform.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line with scatter markers at each sample.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the orbital variable $\chi$) and has two outputs. We'll make a function `ODE_model` that takes the state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl.

julia
# Network N(χ):
# - the first layer applies cos elementwise to the input, so N is 2π-periodic in χ;
# - it is followed by two cos-activated hidden layers and a linear 2-output layer.
# Weights are drawn from a truncated normal with std 1e-4 and biases start at zero,
# so the untrained output is close to zero.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# NOTE(review): Xoshiro() is unseeded, so the initial parameters change between runs;
# pass a seed (e.g. Xoshiro(0)) if reproducibility matters.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00021986429; -2.4181165f-5; -8.470619f-5; 7.1279785f-5; 6.120939f-5; -1.7913088f-5; 0.00015896348; -0.00014248546; 1.21979965f-5; -7.3367526f-5; -4.7694986f-5; 0.00015659779; 0.000115183546; 2.879923f-5; -1.1795507f-5; 3.7009806f-5; -3.6008096f-5; -0.00016635376; -2.8346823f-5; 0.0001389595; 7.155961f-5; 5.0911276f-5; -1.996875f-5; -5.135688f-5; 2.0961012f-5; 3.938412f-5; 4.919901f-5; -3.9079732f-5; -0.00019027595; 4.5598048f-5; 1.8472946f-6; 0.00012070028;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.00011970703 -0.00032355357 0.000101992744 9.1304784f-5 1.4774542f-5 -5.9704653f-5 -5.200492f-5 -2.4961906f-5 -9.493938f-5 -0.00015724132 4.0657167f-5 7.4435857f-6 0.0001157241 -0.0001342144 -5.364639f-5 2.9437766f-5 9.4718314f-5 -0.0001384186 -1.0920376f-5 -0.0001651323 8.016184f-5 9.7598095f-5 -2.1194555f-5 -8.837125f-5 -0.000107828826 -2.6179367f-5 -6.3006955f-5 3.614358f-5 8.1621765f-6 -6.2774765f-5 0.00012607242 8.7316876f-5; -0.0001605519 -0.00016089506 -1.5502444f-7 6.5847096f-5 -0.00013100477 3.4589153f-5 -0.0001285018 3.0210198f-5 7.7499986f-5 -1.6721518f-5 6.247313f-5 0.00022599233 -3.9805877f-6 -1.5267664f-5 9.434397f-5 2.8001992f-5 0.00010370607 0.00021680287 -2.4421763f-5 4.0057606f-5 -0.00012190551 6.253007f-6 -1.2432529f-5 3.4886933f-5 1.0571775f-5 -5.506555f-5 0.000118892225 -7.031111f-5 2.1593f-5 6.9839816f-6 0.00010122567 0.00010600677; -3.4972498f-5 0.00011025847 7.4996875f-5 5.4479246f-5 -9.548414f-5 2.4521767f-5 4.3008335f-5 -3.81695f-5 2.6905936f-5 4.4744462f-7 -9.5236035f-5 -1.2437543f-5 6.70544f-5 0.00011723649 1.2296158f-5 6.715789f-5 -4.6082034f-5 -5.571698f-5 8.4814594f-5 2.4966037f-5 5.6367735f-5 -0.00018520378 -4.3858967f-5 -0.00010102258 -4.9989096f-5 0.00024460364 -9.331217f-5 4.0637547f-5 -6.553667f-5 8.961431f-5 
-6.581991f-5 6.6469904f-5; 8.9436784f-5 2.7359089f-5 4.7826306f-5 1.6676824f-5 6.644948f-5 9.299985f-6 -7.069878f-5 0.0001270878 1.533797f-5 0.00016453437 6.700469f-5 0.0001940297 0.00011680404 -1.3197277f-5 5.4674878f-5 -5.5327688f-5 -0.00012519794 3.315894f-5 3.8863152f-5 4.3123327f-5 4.9142513f-5 1.8245633f-5 -6.5871165f-5 2.4534162f-5 0.00018423329 -0.00015743532 -6.480472f-5 -0.00029396 -9.294878f-5 -2.1286214f-5 -1.3585511f-5 3.5051733f-5; 0.000112837326 -7.563444f-5 -1.9104635f-5 4.0810255f-6 -6.282293f-5 1.0062327f-5 0.00011786787 -5.00194f-5 6.970562f-6 0.00012402708 0.000121930265 2.8668046f-6 -1.1273815f-5 -0.00020992907 -9.69282f-5 6.11035f-6 0.00010475568 7.100674f-5 -5.6000652f-5 -9.806528f-5 -0.0003440598 -5.5587047f-5 -0.00011632794 0.00014384846 5.3769738f-5 -5.0716764f-5 -7.708521f-5 1.9347002f-5 4.291568f-5 -9.227676f-5 0.00020744397 -3.918885f-5; 8.2035855f-5 0.000105937565 -1.9382549f-5 9.261042f-5 1.461116f-5 -6.98951f-5 1.3479737f-5 -1.1602622f-5 -0.00010849625 0.00015370223 0.000141402 -0.00018021806 -0.00015547655 3.8662183f-6 -1.6462902f-5 0.00015287129 -5.7568068f-5 -0.00010050102 -6.354201f-5 1.999838f-5 8.052519f-5 2.387068f-5 -9.453828f-5 -6.887804f-5 3.6241192f-5 2.2187765f-5 -7.216543f-5 0.0002089625 -4.6083314f-5 9.335008f-6 0.0001661711 1.9140773f-5; -9.896355f-5 -0.00011393562 -0.00014026837 0.00014402469 0.0001263696 3.339788f-5 9.58135f-5 -8.0154736f-5 9.712953f-5 -3.0825107f-5 7.802656f-5 2.4433293f-5 5.6246445f-5 -2.4642726f-5 3.4101275f-5 -7.8197336f-5 -7.604441f-5 -9.98053f-5 -9.576849f-5 -6.622428f-5 -1.8423996f-5 3.488266f-5 -0.00015519587 3.5391684f-5 -2.1922224f-5 -2.53511f-5 5.1339768f-5 -7.2795665f-5 -0.00016455965 6.584748f-5 -9.662017f-5 5.8526042f-5; -0.0001425721 -5.0824943f-5 -8.786429f-5 8.1413746f-5 -0.0001276605 -3.7145885f-5 -8.72819f-5 -9.859716f-6 0.0001043537 7.5034266f-5 2.8030272f-5 0.00017544297 -0.00013811844 -0.0001519299 7.783843f-5 1.0102858f-5 8.3686195f-5 6.749608f-6 2.4640785f-5 -0.00014541291 
1.5982145f-5 0.00013564734 -3.0732634f-5 -4.7531263f-5 0.00037124305 -7.954508f-5 -8.1259794f-5 0.00015950829 0.00024116538 -7.145377f-5 -6.449087f-5 0.00010523666; -7.0623944f-5 1.3346076f-5 2.0430585f-5 -6.571638f-5 9.3115406f-5 2.0306341f-5 4.4755366f-6 0.000121214885 -0.00012945442 0.00016604454 -0.00023726054 -9.967462f-5 9.1398004f-5 -3.6116802f-5 -1.8464129f-6 -2.4610863f-5 -0.000134255 -0.00017700886 -1.1803601f-5 -7.138554f-6 5.4827044f-5 -0.00015283664 0.00022035728 -7.399659f-6 4.037378f-6 -3.359489f-5 5.6685218f-5 2.572607f-5 3.330527f-5 0.00010200961 -0.00012307886 9.9864126f-5; 0.0002256101 -0.00015355629 7.1782626f-5 -3.635245f-5 -2.1414453f-6 2.6295902f-5 -9.074421f-5 -0.00015311434 3.5926376f-5 0.000156137 0.00010558868 -2.812004f-5 -7.251379f-5 4.0655643f-5 5.6096243f-5 0.00015046819 -7.6313285f-5 -7.348414f-5 0.00010241055 8.955638f-5 0.000112746806 -8.057876f-5 -6.390426f-5 7.7377896f-5 0.00014596565 4.4616292f-5 2.6240894f-5 -2.216276f-5 0.00014764581 -3.0207571f-5 0.00017875302 0.00028719142; 8.262174f-5 0.00023824963 -0.00014883629 -0.00012542668 -0.00014107458 6.3629646f-5 -1.4881959f-5 6.2133535f-5 4.2149524f-5 -0.00021962536 -0.0001224595 -0.00011656336 3.6519552f-5 7.869158f-5 -1.7966942f-5 1.5735606f-5 -6.250552f-6 -0.00011641836 -9.824867f-5 4.1683586f-5 5.6012625f-5 2.2972665f-5 0.00018613617 -7.001371f-5 -3.9800983f-5 -0.00010276166 -8.686349f-5 0.000117190124 -5.7744193f-5 4.4595927f-5 4.92096f-5 -0.000154023; 0.00015357607 -1.9273359f-5 -0.0002458435 6.508857f-5 -5.0794344f-5 0.00014306932 -7.434883f-5 -3.6830123f-5 -0.00014685706 1.5345242f-5 -8.025687f-5 0.000106712476 -7.693994f-5 3.8232803f-5 -0.000121897974 -7.787598f-5 0.0001441025 -9.490241f-6 -1.2814366f-5 0.000102755585 -6.032581f-5 0.00013220524 0.00025423607 0.00018353968 5.2394294f-6 -5.382615f-5 -0.00014838656 -4.512666f-5 8.264418f-5 -0.00019726143 -7.890754f-5 6.399465f-5; 5.7079378f-5 0.00013295544 5.2303876f-5 0.00015576526 -4.505605f-5 0.00026099337 -6.142928f-5 
4.184877f-6 0.000162935 -8.164639f-5 7.204224f-5 0.00011680522 -0.00018416124 -9.1776186f-5 7.9492485f-5 0.00017195176 -7.0144444f-5 -0.00010023847 -4.776693f-5 4.2917593f-5 8.988766f-5 -3.0140787f-5 7.246558f-5 0.00019953889 8.093004f-5 -2.0849611f-5 -2.8352072f-6 2.0447045f-5 -1.6330279f-6 -0.00016161398 2.6088866f-5 5.9331047f-5; -0.0001522644 -2.9498053f-6 -5.1970823f-5 -0.000120694924 7.535382f-5 9.017289f-5 -5.2897627f-5 6.058088f-5 1.7722856f-5 2.2227336f-5 -7.6952194f-5 3.463898f-5 -4.354354f-6 -0.00022866469 8.8956294f-5 -5.4946056f-5 -8.506325f-5 -7.0886104f-5 4.9743147f-5 5.296115f-5 7.383785f-5 1.9270486f-7 0.00012138066 -9.244539f-5 -1.5079453f-5 -4.6771198f-5 -0.0001079332 -0.0001010174 -2.7456708f-5 -0.00015107062 8.106751f-5 0.00012254753; 0.00012430515 6.8306085f-6 -0.00010506123 -1.5918651f-5 0.0002369587 -0.000105464234 -7.5329786f-5 1.0541065f-5 -8.458628f-5 -1.5393742f-5 0.00010512236 8.270782f-5 -3.7875972f-5 0.00015805321 7.015058f-5 0.00011300398 2.8047973f-5 -6.256494f-5 -0.000112150505 2.142213f-5 -0.00018489556 -0.00018711564 1.1976383f-5 0.000119453696 -1.8142768f-5 -9.049327f-5 0.00016819665 4.3249233f-6 1.3386598f-5 2.1709302f-6 9.3085655f-5 0.00011920633; 5.729451f-5 0.00020788421 -4.3494827f-5 0.00018928846 -2.8482831f-5 2.2937753f-5 8.0990336f-5 0.000124084 -1.2563773f-5 8.4655854f-5 -0.00010047057 6.3497166f-5 8.008529f-5 -5.4528327f-6 -2.0323561f-5 -0.000106067855 8.0588805f-5 -7.26889f-5 -0.00010266684 -3.0146026f-5 3.4432924f-5 3.4495042f-5 -0.00010024183 5.2479925f-5 -1.3460092f-5 -6.676351f-5 -0.00011265794 -0.0001101304 -7.387956f-5 -2.999257f-5 7.330146f-5 -6.828447f-5; -0.00021625523 0.00013715313 -5.0959326f-5 -3.132915f-5 -0.0001700314 7.963608f-5 -9.120197f-5 -5.7185247f-5 -6.0241488f-5 -3.517341f-5 -2.9812974f-5 0.0001315818 -9.541739f-5 0.00013281558 -6.645306f-5 0.0001515339 0.0001229446 -0.00026193523 -2.5097273f-5 -0.00013624712 -8.738891f-5 -4.8950522f-5 -1.1370713f-5 5.7023855f-5 5.651799f-5 -7.4539685f-6 
-8.125414f-5 -9.654005f-5 -4.0716324f-5 9.961917f-5 -5.9844646f-5 -2.5756464f-5; 0.000103122584 -5.254926f-5 6.280847f-5 -0.00019250665 -3.5651126f-5 0.00010736441 0.00013069955 0.000111548456 9.577019f-5 -1.9755016f-5 -0.00012561801 -3.4977802f-5 -2.001005f-7 -9.342732f-5 7.402082f-5 -3.4641787f-6 2.2644448f-5 -6.789425f-5 -0.00012510609 2.3021219f-5 -2.4138844f-6 -0.00011886836 2.8005396f-5 -2.472407f-5 6.873307f-5 -8.779668f-5 0.0001975219 0.00016496111 -0.000103693084 5.2645668f-5 -2.6648404f-5 7.8784404f-5; -2.2276896f-5 -7.4065785f-5 -7.625459f-5 -0.000121762 6.536258f-5 -7.027805f-5 1.9348678f-5 9.2483155f-5 5.6458528f-5 1.11882355f-5 5.2028186f-5 -0.00012087703 0.00015171338 6.8886075f-5 0.0001548027 0.000115709445 0.000112052265 0.00014013208 -1.6478878f-5 0.00011483098 -0.00013496862 0.00015306943 6.351451f-5 -2.0377725f-5 -4.2687476f-5 -5.2078312f-6 4.3810727f-5 7.4438874f-5 -4.7925998f-5 7.501983f-5 -1.2696952f-5 4.214785f-6; 0.00013781762 -5.0243634f-5 0.00014110848 6.799904f-5 0.00013717425 1.811924f-5 2.4977269f-6 0.00013493249 0.000119322736 -2.3513999f-6 9.8731936f-5 -2.989994f-5 -0.00021029133 -3.7052857f-8 0.00012230666 5.7951347f-5 4.0268325f-5 5.0470207f-5 -1.587506f-5 1.4790493f-5 -1.4898094f-5 -0.00030924208 0.00016364669 6.294498f-5 2.6629425f-5 7.332333f-6 -6.668801f-5 -4.5836943f-5 3.6606318f-5 -3.3921046f-5 -5.7379617f-5 -1.3284139f-6; -9.779987f-5 -6.041408f-5 -0.00012765385 -0.00012286683 0.000104263374 0.00012344745 -2.9422803f-5 -0.00012357537 -2.4235736f-5 -0.00025470922 1.6557835f-5 0.00012967594 6.352435f-5 -4.6323406f-5 -2.6217915f-5 3.4795998f-5 3.0885258f-5 0.000104919396 0.0001288842 3.100607f-5 2.9420276f-5 -2.3630431f-5 -4.1386356f-5 5.4171847f-5 9.482797f-5 1.33753365f-5 -2.9676221f-5 7.970082f-5 -9.134384f-5 -0.00013081674 2.8468507f-6 3.8243572f-5; -2.2925476f-6 0.00014101223 -2.9348095f-5 -2.2968768f-5 -0.000113014845 -0.00019485771 4.8800663f-5 8.71154f-5 -6.9511276f-5 9.315522f-5 3.4609853f-5 0.000116031566 -4.100345f-6 
-7.462584f-5 0.0002924409 -0.00022787035 8.82813f-5 -2.4439496f-5 -0.00022373817 0.00015438831 0.000105726795 6.029131f-5 -0.00021323183 7.5739456f-7 -0.00014027416 -6.800228f-5 0.00014831005 -5.097668f-5 -7.704775f-5 -2.7391352f-5 -4.8337268f-5 -2.0462607f-5; 6.8061054f-5 0.00011450267 -0.00010563719 -0.00010747289 -0.00018790684 6.702332f-5 7.657032f-5 -1.47335495f-5 -0.0001327249 -0.00010151479 -5.350014f-5 -4.1763983f-6 8.4770865f-5 1.0884004f-5 1.652895f-5 -0.0001103781 -2.0224248f-5 9.288841f-5 2.2296439f-5 1.0437467f-5 5.6513363f-5 -0.00010572844 0.0001891634 0.00011986841 -7.7725126f-5 -8.996066f-5 -7.742269f-5 -0.00014895688 -0.00010975902 2.166479f-5 0.00019009855 8.812207f-5; -1.0573368f-5 7.252006f-5 8.895753f-5 -1.5280542f-5 -0.00026071697 0.00022608692 -0.00023659687 -0.00023000267 -6.485892f-5 0.000114435956 -0.00012215914 -4.1294377f-5 1.3903856f-5 -4.7348043f-5 3.589925f-5 3.9392933f-5 5.0811854f-5 -2.0453774f-5 0.000111414396 -0.00011160603 0.0001241526 -4.0450737f-5 -4.005589f-5 8.679368f-5 0.00010031366 7.4826465f-5 -2.1860384f-5 -4.369396f-5 -0.00016589274 0.00020271809 -2.5046256f-5 0.0001307953; 0.00022860515 -5.6197467f-5 -1.4616204f-5 -7.6996475f-5 -0.00026992525 -4.7766312f-5 2.8560917f-5 3.0272351f-5 -2.360277f-6 -0.00023704766 3.7478396f-5 9.779773f-5 -5.6107645f-5 -6.522516f-5 0.0001499253 6.8457455f-5 7.601139f-5 -0.00011968743 -7.562342f-5 0.00016417711 0.00016028952 -0.000100713434 5.777053f-5 -4.590717f-5 -0.00010004554 0.00013894115 -4.0977375f-5 -0.00026432623 1.8358898f-5 -0.00023622182 -0.00012384522 7.123681f-5; -0.00013992794 -7.269973f-5 -0.000105029096 -0.00020517051 0.00019564143 0.00010208752 0.00016527213 -4.6664467f-5 -0.00017078227 0.000121247875 -2.0600007f-6 0.00010900713 5.168722f-5 -6.7127257f-6 -4.3825494f-5 -9.639971f-5 -2.6894457f-5 -0.00022076932 -0.00010975378 -7.7859266f-5 1.3137432f-5 3.4782537f-5 -8.169942f-5 5.2081755f-6 -1.8572486f-5 0.0001637624 -4.6630168f-5 -3.9384818f-6 -3.389552f-5 -0.00013646728 
5.4795714f-6 2.6424355f-5; -5.2433377f-5 -0.0001110018 5.5358927f-5 -7.7152574f-5 -5.202054f-5 0.00020800519 9.317272f-5 0.00016705439 9.245014f-5 4.0937422f-5 -7.875313f-5 -8.09254f-5 -6.213016f-5 -8.302398f-5 3.6213016f-5 -4.348941f-6 -2.5291813f-5 3.878223f-5 -7.315126f-5 -9.228824f-5 -0.00015429931 -4.941953f-5 -0.00011058081 8.4873885f-5 0.00014192714 4.585271f-5 8.3924526f-5 -9.404668f-6 0.0001034626 -0.00013766959 -0.000105060695 -9.1103524f-5; -0.00013296049 0.000121672056 -0.0001199812 -0.00017570025 0.00017074219 -5.778934f-6 -0.00011839917 5.7921236f-5 7.7079436f-5 0.000112308524 0.0001293142 -2.4613388f-5 1.5306028f-5 1.2436248f-5 -6.642914f-5 7.133881f-6 6.134977f-5 -6.766797f-5 -2.2893717f-5 -3.487971f-5 -7.7706456f-5 0.00010025689 -3.114529f-5 -7.137389f-5 -3.0843374f-5 -2.3316697f-5 5.275603f-5 1.4345207f-5 0.00012294597 -5.9981696f-5 -8.078266f-5 2.3206512f-5; -0.00011633337 -0.00010606313 0.0001271389 6.751047f-5 7.0736474f-5 -0.00023915838 0.00013134563 -9.773075f-5 -9.952802f-5 3.8191465f-5 -2.2884959f-5 -6.578526f-5 -2.6989088f-5 -0.00012622781 -1.0960236f-5 1.2591512f-5 -0.0001480231 -0.00011232845 2.5271172f-5 9.113491f-5 6.473698f-5 -0.00019960631 0.000165029 0.000101524965 1.1067974f-6 3.196999f-5 4.9142815f-5 -5.837631f-5 -8.621179f-5 4.4939152f-5 -3.8861315f-5 8.62582f-5; -4.4957003f-5 2.304612f-5 0.000114514354 3.592797f-5 0.00016507923 -0.00012255555 -6.723204f-5 -0.00010632834 -7.435364f-5 5.0218023f-5 2.6975607f-5 -8.7020206f-5 -5.4197004f-5 -4.750189f-5 -3.564283f-5 -9.222104f-5 -1.0648973f-5 -0.00012119442 9.3295246f-5 -6.7426394f-5 -0.00013777592 9.2129616f-5 -4.2424213f-5 6.820344f-5 -8.577842f-7 4.5854802f-5 -2.0011763f-5 2.0584692f-5 0.00011201219 -9.964079f-5 4.4063767f-5 0.00011549698; 3.3283573f-5 2.7137228f-6 0.000102871716 -1.2390793f-5 7.685554f-6 4.9038485f-6 -4.2749007f-6 -5.5969143f-5 -4.024605f-5 -5.064937f-5 3.7872378f-5 -6.2782936f-5 -0.00016820489 -0.00015452981 3.1356337f-6 -9.944102f-5 -4.3236327f-5 -2.7915736f-5 
3.9454153f-5 4.8719103f-5 -7.98698f-6 0.000166839 0.00012160293 9.059506f-5 0.00011591627 -1.4767141f-5 -2.5857005f-5 0.00015263686 -0.0001947737 -0.00011853977 4.743893f-5 4.841553f-5; 0.00015316457 -1.5828506f-5 -1.5589469f-5 -0.00018694728 0.00012063567 0.00014692437 -0.0001978313 0.00023188679 -4.914752f-5 0.00011536652 -7.0370384f-5 -2.2241318f-5 -0.00010129425 -0.00011334474 5.698331f-6 -7.616173f-5 -5.5640307f-5 -0.0001410393 -3.185859f-5 -0.00011036851 -0.00010894888 -8.473684f-5 -3.3007236f-5 5.521409f-6 -1.7030514f-5 -0.00014570355 -0.000100087 -0.00013691174 -0.0001926801 0.00010582357 3.8979364f-5 0.00020689695], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[3.4458058f-6 0.0001786903 -0.00010543029 0.00017717486 -5.478466f-5 1.8827444f-5 5.0352774f-5 0.00011531756 1.4993323f-5 -1.0428618f-5 -1.0160871f-5 0.00015586319 -0.00015523293 0.00015465448 7.855496f-5 6.804422f-5 0.00016364617 1.3091831f-5 1.581715f-6 0.00014978822 0.00013716338 3.8247414f-5 9.50873f-5 -3.4004996f-5 -5.3482836f-5 3.2206714f-5 -0.00010660397 -0.00021703773 1.5807282f-5 -0.00013275901 0.0002052564 7.7989615f-5; 0.00011446834 -0.00012881022 8.706643f-5 -0.00012640229 4.8734262f-5 -5.738212f-5 0.00017840462 5.420183f-5 1.3589673f-5 9.848639f-5 0.00013749456 9.746136f-5 0.0001093063 -1.09208f-6 3.48478f-5 -0.00017908297 -0.00015626762 0.00018652079 -4.6589583f-5 -1.1652096f-5 -2.8026108f-5 -8.1809405f-5 7.6267697f-6 -8.9926914f-5 0.00011439855 1.6931139f-5 -5.381971f-5 1.8379533f-5 2.2593467f-5 2.9163308f-5 3.5035862f-6 2.6733635f-5], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64.

julia
# Lux initialises parameters as Float32. Convert them to a Float64 ComponentArray so
# they match the Float64 state `u0` and can be optimised as a single flat vector.
const params = ComponentArray{Float64}(ps)

# Bundle the network with its (empty) state so it can be called as nn_model(x, ps).
const nn_model = StatefulLuxLayer(nn, st)
Lux.StatefulLuxLayer{true, Lux.Chain{@NamedTuple{layer_1::Lux.WrappedFunction{Base.Fix1{typeof(broadcast), typeof(cos)}}, layer_2::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_4::Lux.Dense{true, typeof(identity), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}}}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}, layer_4::@NamedTuple{}}}(Chain(), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()), nothing)

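Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics plus neural-network corrections. It uses $u[1] = \chi$ and $u[2] = \phi$:

$$\dot\chi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot\phi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr)$$

where $p$, $M$, and $e$ are constants and $\mathcal{N}$ is the neural network.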

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural orbit model for the state `u = [χ, ϕ]`.

Both rates share the Newtonian prefactor `(1 + e cos χ)² / (M p^(3/2))`. Rate `k` is
multiplied by `1 + N_k(χ)`, where `N = nn_model` is evaluated with `nn_params`.
The constants `(p, M, e)` are read from the global `ode_model_params`; `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ = first(u)
    p, M, e = ode_model_params

    # `st` is an empty NamedTuple in this example, so the stateful wrapper can be called
    # with parameters alone. A network with non-empty state would need `st` threaded
    # through explicitly.
    correction = 1 .+ nn_model([χ], nn_params)

    prefactor = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [prefactor * correction[1], prefactor * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings as the true model, and keep the h₊ component of its waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained neural-ODE waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
"""
    loss(θ)

Simulate the neural ODE with parameters `θ` and compare its h₊ waveform to the target.

Returns a tuple:
- the sum of squared differences between the target `waveform` and the prediction;
- the predicted waveform itself.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    strain = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    pred_waveform = first(strain)
    residual = waveform .- pred_waveform
    return sum(abs2, residual), pred_waveform
end
loss (generic function with 1 method)

Warm up the loss function (the first call compiles the solve and waveform code).

julia
# Call the loss once before training. The first call compiles the solve and waveform
# code. It returns (sum of squared errors, predicted waveform).
loss(params)
(0.19334622800474166, [-0.024293821549663693, -0.023507048162518262, -0.022720274775372733, -0.021390831014647487, -0.019490823124787646, -0.01697975560866544, -0.013804446902230203, -0.009896046634256871, -0.005172302940155367, 0.0004627801686267466, 0.0071096464093933385, 0.014851953660109374, 0.02370820022185066, 0.03350123041605052, 0.04357467713850848, 0.05212204384669369, 0.05476831214415505, 0.04227693639989014, 0.0010384965866437933, -0.0675927314518358, -0.11048685852769646, -0.07513244006572291, -0.005857087670966944, 0.03907881191494629, 0.05414478755767037, 0.05272592096459218, 0.04465389534727326, 0.03470674283208557, 0.024886837035828356, 0.015942946090355185, 0.00808909907711154, 0.0013268166271728408, -0.004421304725060299, -0.009253204271710112, -0.013264437040127449, -0.016536969531617093, -0.01914057370132563, -0.021129685127670862, -0.022545505190349525, -0.02341682492619792, -0.023760938710138332, -0.023583942473678704, -0.022881941190694494, -0.021638686958764836, -0.019827583055748218, -0.017409295352323786, -0.014330877060933807, -0.010525108158519557, -0.005909592829355096, -0.00038821881289455653, 0.0061412580881653415, 0.01377080780746639, 0.02253184066165866, 0.032286182168017895, 0.04245843418927985, 0.051424768703749395, 0.05518666017609473, 0.04510532708105395, 0.0075656780594411084, -0.059726155108962786, -0.10938147097057824, -0.0822212790704335, -0.013093397963333704, 0.03549889523179043, 0.053301268788831846, 0.0532267859438894, 0.04569169725547241, 0.035899676499046525, 0.02606688601664774, 0.017042541137893597, 0.009079038594746949, 0.0022049621114875865, -0.0036584113981694266, -0.008596011414082623, -0.012711534095613828, -0.016079658892382352, -0.018777754781217654, -0.020854685857680626, -0.022357952809747894, -0.023313210474882582, -0.023741289595206962, -0.023648162933459647, -0.023029890538188165, -0.021873027017929238, -0.02015166295545278, -0.017824268123395298, -0.01484469060904796, -0.011139339891915413, 
-0.006634245028793924, -0.0012250429379049718, 0.005184508369117935, 0.012699749558014171, 0.021358896738695517, 0.03106438721358143, 0.04130896050550679, 0.05064255956382946, 0.055416631842018255, 0.04758032490513654, 0.013695716020002665, -0.05164756743131878, -0.10718734806142606, -0.08875272859387671, -0.020615671607541495, 0.03152325009597398, 0.05222037247904038, 0.05361512608055689, 0.04668250138635166, 0.037076790124999585, 0.02724923007539701, 0.018146670060657108, 0.010083473688762046, 0.0030917931769639585, -0.0028785271773114464, -0.007928636816674394, -0.012142278521191537, -0.015611660132769808, -0.018399327992774144, -0.02056825335793303, -0.022155924560550558, -0.0231972202074091, -0.023708312197548112, -0.023698666249044303, -0.02316563067219954, -0.022093937207312186, -0.020461535784938558, -0.018226501260287374, -0.015343785239554514, -0.01174065098136249, -0.007344181171815618, -0.0020491854978215806, 0.0042411309908648885, 0.0116379867045766, 0.020192398033914538, 0.029835849318991813, 0.040132034085454474, 0.049783671819199585, 0.055470137833203384, 0.04972284169371124, 0.01940412871102195, -0.0434562234360991, -0.1039560377297244, -0.09460693994688671, -0.0283703591413611, 0.027140920663161476, 0.050887792191699734, 0.053879301884412485, 0.047621061271660076, 0.03823762456111925, 0.028425920411405375, 0.019262545867156727, 0.011094568270989574, 0.003994713463733555, -0.0020917044288990176, -0.0072422398218730455, -0.011562797297533312, -0.015127777311245658, -0.018011384606365907, -0.02026531992350128, -0.021942259042491637, -0.023067507693735802, -0.023662468553739855, -0.023736319599907418, -0.023287408268476247, -0.02230244047497662, -0.020755868783247335, -0.018617327177529026, -0.015828168634553826, -0.012327071364301188, -0.008041014588314226, -0.0028610348279097964, 0.003311797778784306, 0.010588379976819618, 0.019030358363799484, 0.028603889485840334, 0.03893086221850597, 0.04885481339058799, 0.055365203047921495, 0.0515452286796966, 
0.024689630587400868, -0.03526668908103007, -0.09974097993097736, -0.0996782094772404, -0.03629442234066166, 0.02235205015532673, 0.0492802431084593, 0.054011255199313365, 0.0485012279262307, 0.039375040280646256, 0.029605905977220476, 0.020379133301800544, 0.012117080638024615, 0.00490572456599037, -0.0012838627027825583, -0.006548349855466752, -0.010969043027883478, -0.014632961545616503, -0.01760471883212108, -0.019952925548791477, -0.02171483505681524, -0.02292506149126192, -0.023602883176537363, -0.023760962338396682, -0.023396045074951792, -0.022495708390988082, -0.021041268499942967, -0.018990895979767584, -0.016299614422181488, -0.012900465705418118, -0.008722757841912074, -0.003657826847664596, 0.002394781764067234, 0.009548596224483173, 0.017877436358362187, 0.027370294566587083, 0.03770810656245143, 0.04786422297577901, 0.05511190795588265, 0.05306950606257021, 0.029540756054919286, -0.027159839437757904, -0.09462890918657486, -0.10387370424327268, -0.04430105752107611, 0.017146743077219275, 0.04738752339023777, 0.05399602776687352, 0.04931661355544649, 0.040490376728126574, 0.030777364606893474, 0.021502649888207413, 0.013148662585224522, 0.005830626054034542, -0.0004679502304197868, -0.005839309645944943, -0.010361928434274428, -0.014123399716546924, -0.017188107205282088, -0.01962576337110008, -0.02147433181037581, -0.022769295576417065, -0.023530647771653427, -0.023772221558414734, -0.023491432515454116, -0.022676881349239043, -0.021310565733175458, -0.019352077944799578, -0.016757060334976374, -0.013459797481710466, -0.009390381988722265, -0.005320966495733892])

Now let us define a callback function to store the loss over time

julia
# Loss history: one entry is appended per optimizer callback invocation, in call order.
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Optimizer callback used during training. Appends the current loss `l` to the global
`losses` history and prints a one-line progress report. The parameters `θ` and the
predicted waveform `pred_waveform` are accepted to match the callback signature but are
not used. Always returns `false`, which tells Optimization.jl to keep iterating.
"""
function callback(θ, l, pred_waveform)
    # `push!` returns the updated vector, so its length is the current iteration number.
    iteration = length(push!(losses, l))
    @printf "Training %10s Iteration: %5d %10s Loss: %.10f\n" "" iteration "" l
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It appears to work well here because the Newtonian model already provides a very good initial guess.

julia
# Gradients of the loss with respect to the network parameters come from Zygote.
adtype = Optimization.AutoZygote()
# The second closure argument (problem hyperparameters) is ignored: only the trainable
# parameters are passed on to `loss`.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and a backtracking line search; `callback` logs the loss.
res = Optimization.solve(optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00021986429055645165; -2.4181164917533732e-5; -8.470618922720676e-5; 7.127978460619064e-5; 6.120939360696846e-5; -1.7913087503984652e-5; 0.0001589634775883623; -0.00014248545630837146; 1.2197996511507e-5; -7.336752605615734e-5; -4.7694986278582596e-5; 0.0001565977872814801; 0.00011518354585868696; 2.879922976713469e-5; -1.1795506907204811e-5; 3.7009805964697993e-5; -3.6008095776202085e-5; -0.00016635375504804072; -2.834682345562621e-5; 0.00013895949814447247; 7.155961066017259e-5; 5.0911276048297056e-5; -1.9968749256803508e-5; -5.135688115838689e-5; 2.0961011614380127e-5; 3.938411828130642e-5; 4.919901039096515e-5; -3.907973223244552e-5; -0.0001902759540823056; 4.5598048018282546e-5; 1.8472945839663445e-6; 0.00012070027878500935;;], bias = [-4.693822811142114e-16; -2.4802839072389227e-18; -9.284668254221724e-17; 1.2766467041830115e-16; 2.103162573070336e-17; 6.57414948397178e-20; -1.1605048823605374e-16; 2.3156010871559968e-17; 2.7529170614732866e-17; -9.717646930681633e-17; -8.339235040603794e-17; 1.9949684474116876e-17; -9.646925005964886e-17; 4.334859765350181e-17; -1.6543583543206164e-17; 8.244127922773964e-17; 6.331857630864361e-18; -4.88790083677719e-16; -2.5904791627436382e-17; 2.4036821749950068e-16; 1.7623512205922983e-17; -1.5946606225523637e-17; -1.3230564787666013e-18; -4.390187893126288e-17; 4.174144358871672e-17; -2.016581355761714e-18; 6.239924782247401e-17; -5.44777247610732e-17; -3.35346592720933e-16; 1.0135245975682824e-17; 1.4114091325752443e-18; 7.889225292032797e-17;;]), layer_3 = (weight = [-0.00011971057561870463 -0.0003235571146984371 0.00010198919893757544 9.13012392078792e-5 1.4770997167415327e-5 -5.970819789634774e-5 -5.2008466682812614e-5 -2.4965451175884907e-5 -9.494292486177534e-5 -0.0001572448617121867 4.065362203565895e-5 7.440040589650285e-6 0.00011572055349409456 -0.00013421794015865466 -5.3649933634537825e-5 2.943422090485591e-5 9.471476927448422e-5 
-0.00013842214853916637 -1.0923920727313296e-5 -0.00016513584147156572 8.015829202069768e-5 9.759454964243381e-5 -2.119810038546779e-5 -8.837479423216965e-5 -0.00010783237108679231 -2.6182911697636223e-5 -6.301050039890667e-5 3.614003629938151e-5 8.158631444039867e-6 -6.277831003952324e-5 0.0001260688787983006 8.73133304953019e-5; -0.00016054959382163692 -0.00016089274253460148 -1.5271181241128286e-7 6.584940889154656e-5 -0.00013100245381935305 3.459146569951885e-5 -0.0001284994807443443 3.021251064494146e-5 7.750229884407944e-5 -1.6719205852103637e-5 6.247544548244519e-5 0.00022599464751548647 -3.978275109670842e-6 -1.5265351286475163e-5 9.434628624006635e-5 2.800430480277536e-5 0.00010370838459351475 0.00021680518580834658 -2.4419450020618336e-5 4.005991900923629e-5 -0.00012190319943471721 6.255319711666557e-6 -1.2430216194447359e-5 3.488924517881402e-5 1.0574088044846171e-5 -5.506323690639451e-5 0.00011889453777115152 -7.030879463997313e-5 2.1595312099094648e-5 6.986294245902368e-6 0.00010122798153901013 0.00010600908604685539; -3.497149048817189e-5 0.00011025947930430861 7.499788222740985e-5 5.448025291808598e-5 -9.548313067388029e-5 2.452277437224858e-5 4.300934194143846e-5 -3.816849169415275e-5 2.690694379235963e-5 4.484519468976749e-7 -9.523502779519375e-5 -1.2436535537193917e-5 6.705541051597713e-5 0.00011723749373014875 1.2297165601571284e-5 6.715889646112759e-5 -4.6081026477973644e-5 -5.571597324196344e-5 8.481560104666138e-5 2.4967044344164368e-5 5.63687422687625e-5 -0.00018520277483947777 -4.38579594923662e-5 -0.0001010215750578811 -4.9988088387531604e-5 0.0002446046451656131 -9.33111627385971e-5 4.063855482637715e-5 -6.553566020775031e-5 8.96153174540941e-5 -6.581890596170092e-5 6.647091101295651e-5; 8.943892340326621e-5 2.7361227928571997e-5 4.7828445765562134e-5 1.6678963454697582e-5 6.645161943896887e-5 9.302124688196776e-6 -7.069663819704541e-5 0.00012708993768472234 1.5340109037146867e-5 0.00016453651210912606 6.700682593663538e-5 
0.00019403183274123534 0.00011680617922837405 -1.3195138073848128e-5 5.467701726747208e-5 -5.532554836673284e-5 -0.00012519580172220792 3.3161079279738126e-5 3.8865291806081754e-5 4.3125466852580016e-5 4.9144651946054104e-5 1.824777269366026e-5 -6.586902576810236e-5 2.4536301005075805e-5 0.00018423542520185445 -0.0001574331833965574 -6.480258138463527e-5 -0.00029395786963714844 -9.294664027250528e-5 -2.1284074812053204e-5 -1.3583371348709268e-5 3.505387235935247e-5; 0.00011283593364126686 -7.563583033560441e-5 -1.9106027075837554e-5 4.079633135472685e-6 -6.282432416854878e-5 1.006093504612191e-5 0.00011786647980399952 -5.0020792451032256e-5 6.96916965308448e-6 0.00012402568706376097 0.00012192887249467376 2.8654122338795606e-6 -1.1275207662477842e-5 -0.00020993046467127863 -9.692958968161291e-5 6.108957740717074e-6 0.00010475428741257075 7.100534737175963e-5 -5.6002044339483884e-5 -9.806666906159117e-5 -0.00034406120078573396 -5.5588438890926014e-5 -0.00011632932994915778 0.0001438470690750466 5.376834548725142e-5 -5.071815660857218e-5 -7.708660535216844e-5 1.934560993459918e-5 4.291428655917845e-5 -9.227815001439831e-5 0.0002074425766018157 -3.919024214199707e-5; 8.203716348150176e-5 0.00010593887341901256 -1.938124009239706e-5 9.261172741179689e-5 1.4612468619536805e-5 -6.989378820645085e-5 1.3481046296585736e-5 -1.1601313274072786e-5 -0.00010849494068889334 0.00015370354154373844 0.00014140330236631904 -0.00018021675427218456 -0.00015547523875494498 3.86752722778874e-6 -1.6461592933688446e-5 0.00015287259808036992 -5.75667589530951e-5 -0.00010049971100881387 -6.354069850413544e-5 1.999968950047821e-5 8.052650005783288e-5 2.387198870355239e-5 -9.453696915653e-5 -6.887673302326842e-5 3.6242500898601225e-5 2.2189074259307437e-5 -7.216411981272063e-5 0.00020896381389110837 -4.608200546710805e-5 9.336317050559421e-6 0.00016617240831482774 1.9142081471443443e-5; -9.896596847743251e-5 -0.00011393803387826801 -0.00014027078677380432 0.0001440222685277262 
0.0001263671864807395 3.33954621100095e-5 9.581108216018937e-5 -8.015715238350314e-5 9.71271137176774e-5 -3.082752400140392e-5 7.80241452093014e-5 2.44308766058832e-5 5.6244028114560476e-5 -2.464514281663585e-5 3.4098858036402926e-5 -7.819975243827759e-5 -7.604682568407124e-5 -9.980771673787606e-5 -9.577090631418405e-5 -6.622669852283732e-5 -1.8426412766609757e-5 3.4880244952464865e-5 -0.00015519828693422066 3.538926730917436e-5 -2.1924640436820135e-5 -2.5353517317026353e-5 5.1337351492211476e-5 -7.279808150550768e-5 -0.0001645620660338506 6.584506527807189e-5 -9.662258625675531e-5 5.8523625652852045e-5; -0.0001425704943658929 -5.082333912745947e-5 -8.786268548014813e-5 8.141534969304558e-5 -0.00012765889944880328 -3.714428051629919e-5 -8.728029600484546e-5 -9.858111600857136e-6 0.00010435530172522305 7.503587009246313e-5 2.8031876051458674e-5 0.0001754445726573446 -0.0001381168370544873 -0.00015192829262836751 7.784003143292979e-5 1.0104462186932247e-5 8.368779860298215e-5 6.751212064900655e-6 2.4642389386786123e-5 -0.00014541130470484436 1.598374880536792e-5 0.00013564894076530822 -3.073102955845072e-5 -4.752965863467199e-5 0.0003712446520393818 -7.954347923536733e-5 -8.125818961438715e-5 0.0001595098907882446 0.00024116698811388684 -7.145216325902355e-5 -6.448926827310451e-5 0.00010523826100956956; -7.062480047757265e-5 1.3345220091227265e-5 2.0429729070689322e-5 -6.571723255043649e-5 9.31145502050051e-5 2.0305484818467615e-5 4.474680431978245e-6 0.00012121402930055437 -0.00012945527608686997 0.00016604368526308076 -0.0002372613938263978 -9.967547278798375e-5 9.139714772167111e-5 -3.611765835964038e-5 -1.8472690812503667e-6 -2.4611719104055276e-5 -0.00013425585105730127 -0.000177009712692286 -1.1804457548774431e-5 -7.139410021827649e-6 5.4826187502417875e-5 -0.00015283749720262232 0.00022035642420226777 -7.400515036774873e-6 4.0365218096978646e-6 -3.359574596694568e-5 5.6684361575694164e-5 2.5725214678036136e-5 3.330441419670492e-5 0.00010200875355215858 
-0.00012307971822744175 9.986327010259497e-5; 0.00022561559049624115 -0.00015355080071616654 7.178811680045925e-5 -3.6346957896739096e-5 -2.135954529646106e-6 2.6301392752666318e-5 -9.073871737363249e-5 -0.00015310884449880107 3.593186662170442e-5 0.00015614249164923427 0.0001055941731085624 -2.8114548369364984e-5 -7.250829625424143e-5 4.066113356953703e-5 5.610173426090148e-5 0.00015047367665670935 -7.6307794044389e-5 -7.347864706552515e-5 0.00010241604209864733 8.95618697161105e-5 0.00011275229658917905 -8.057326598804543e-5 -6.389877202467698e-5 7.738338641343531e-5 0.00014597113946207156 4.462178290692246e-5 2.6246384694109927e-5 -2.215726894223039e-5 0.0001476513035942748 -3.020208063548437e-5 0.00017875851077247775 0.0002871969092286526; 8.261944227680898e-5 0.0002382473371323433 -0.00014883858469721812 -0.00012542897687700685 -0.00014107686960079915 6.36273520864779e-5 -1.4884253265608523e-5 6.2131240578034e-5 4.214723000011859e-5 -0.0002196276568941501 -0.00012246179770619925 -0.00011656565269395297 3.651725770731959e-5 7.868928283976893e-5 -1.7969236556673562e-5 1.5733311978402398e-5 -6.252846138813038e-6 -0.00011642065741061515 -9.825096143351367e-5 4.168129222643585e-5 5.601033032115442e-5 2.297037050419191e-5 0.00018613387279736352 -7.001600400719154e-5 -3.9803277592291044e-5 -0.00010276395417924415 -8.686578306543951e-5 0.00011718782976793679 -5.774648756077156e-5 4.4593632504694546e-5 4.920730631391028e-5 -0.00015402529473822896; 0.00015357584963086667 -1.92735766043254e-5 -0.0002458437080215725 6.508835250743969e-5 -5.079456203501376e-5 0.00014306910490150833 -7.43490450852599e-5 -3.6830340455058566e-5 -0.00014685727327131396 1.5345024089322696e-5 -8.025708628819039e-5 0.00010671225792495044 -7.694015911082374e-5 3.823258558869636e-5 -0.00012189819206949984 -7.787619650597405e-5 0.00014410227633080723 -9.490458601508313e-6 -1.2814583519945907e-5 0.00010275536661967109 -6.032602649182708e-5 0.00013220502334180574 0.0002542358564957909 
0.00018353946544695003 5.239211454602575e-6 -5.382636626135279e-5 -0.00014838678142522176 -4.51268784538908e-5 8.2643964937884e-5 -0.00019726164426655225 -7.890775994864023e-5 6.399443410802073e-5; 5.708367005769379e-5 0.00013295972802064796 5.2308168037258554e-5 0.00015576954955088686 -4.505175913015411e-5 0.0002609976645404798 -6.142498843055483e-5 4.1891689621537605e-6 0.00016293929273582537 -8.164209522453396e-5 7.204653137217453e-5 0.00011680951045909738 -0.00018415694741214132 -9.177189441776454e-5 7.949677727505111e-5 0.00017195605231303093 -7.014015228866286e-5 -0.0001002341750931055 -4.776263905602124e-5 4.292188440068268e-5 8.989195247034317e-5 -3.0136495311459136e-5 7.246986841404322e-5 0.00019954318030125586 8.09343318766306e-5 -2.0845319354852673e-5 -2.830915277960197e-6 2.0451337136033017e-5 -1.6287359725704128e-6 -0.00016160968608834573 2.6093158325818675e-5 5.93353388327298e-5; -0.00015226701020808008 -2.95241397954233e-6 -5.1973431482573783e-5 -0.00012069753233891881 7.535121017052709e-5 9.017028401056866e-5 -5.290023568749561e-5 6.057826953974933e-5 1.772024711452504e-5 2.2224726989851872e-5 -7.695480258054306e-5 3.463637017330891e-5 -4.356962791010534e-6 -0.00022866730091710907 8.895368568981611e-5 -5.4948665138467256e-5 -8.506585822989523e-5 -7.088871302629749e-5 4.9740537860529906e-5 5.2958541118188936e-5 7.383524257374031e-5 1.9009620142392437e-7 0.00012137805219889255 -9.244800116949893e-5 -1.5082061736862071e-5 -4.6773806616797e-5 -0.00010793580827386313 -0.00010102001056158252 -2.74593169255988e-5 -0.0001510732292904448 8.106490326962035e-5 0.00012254491934936387; 0.0001243069887361582 6.832445277957516e-6 -0.0001050593968745914 -1.59168143643414e-5 0.00023696053861318016 -0.00010546239761492173 -7.532794959002663e-5 1.0542902172569123e-5 -8.458444469464869e-5 -1.5391904954181543e-5 0.0001051241956968228 8.270965511587005e-5 -3.787413558779988e-5 0.00015805504838613827 7.015241456361563e-5 0.00011300581857472662 2.8049809342753222e-5 
-6.256310413915364e-5 -0.00011214866796830632 2.1423967636333376e-5 -0.00018489372129874525 -0.00018711380514172907 1.1978219417194518e-5 0.00011945553299415786 -1.8140930906831627e-5 -9.049143088372431e-5 0.00016819848999990094 4.326760089731486e-6 1.3388434608017807e-5 2.1727669529869087e-6 9.308749170418911e-5 0.00011920816498718375; 5.729492495441946e-5 0.0002078846229778215 -4.3494412057094846e-5 0.00018928887602664315 -2.8482416200909873e-5 2.2938167785418568e-5 8.099075075728334e-5 0.00012408442144073127 -1.2563358541279199e-5 8.465626906394736e-5 -0.00010047015512105256 6.349758059508623e-5 8.008570166561426e-5 -5.4524178267961396e-6 -2.0323146418427674e-5 -0.00010606744017427954 8.05892197603902e-5 -7.26884825781412e-5 -0.00010266642475476647 -3.014561098103352e-5 3.44333386237339e-5 3.449545711186533e-5 -0.00010024141356557761 5.248034012760785e-5 -1.345967739330758e-5 -6.676309549225664e-5 -0.0001126575223680045 -0.00011012998566007411 -7.387914211000016e-5 -2.9992155577979868e-5 7.330187527254076e-5 -6.828405439596966e-5; -0.00021625912409142857 0.00013714922756549242 -5.0963224158508886e-5 -3.1333047037367734e-5 -0.0001700353000626274 7.963217872089012e-5 -9.120586542049782e-5 -5.7189144673627385e-5 -6.024538520363873e-5 -3.517730651877878e-5 -2.9816872075791298e-5 0.00013157789585310512 -9.542128605173902e-5 0.00013281167998571247 -6.64569575303585e-5 0.00015152999771885943 0.00012294070843268547 -0.0002619391284963121 -2.5101171150930965e-5 -0.0001362510206956527 -8.739280524536127e-5 -4.8954419572729966e-5 -1.1374610540992429e-5 5.701995764753379e-5 5.651409305642613e-5 -7.457866150837658e-6 -8.12580379272202e-5 -9.654394448374942e-5 -4.072022197869527e-5 9.961526889015447e-5 -5.984854356142493e-5 -2.576036199682231e-5; 0.00010312352925651128 -5.254831602101251e-5 6.280941262488946e-5 -0.0001925057008385489 -3.565018094846183e-5 0.0001073653543379341 0.00013070050000791572 0.00011154940103992176 9.577113770760574e-5 -1.975407074965256e-5 
-0.00012561706579176794 -3.4976856554894076e-5 -1.9915506825820968e-7 -9.342637393742845e-5 7.402176517175488e-5 -3.4632332564355127e-6 2.264539354186912e-5 -6.78933069913549e-5 -0.0001251051439659488 2.3022164454998824e-5 -2.412938997761249e-6 -0.0001188674145716066 2.8006340939872592e-5 -2.4723124156618556e-5 6.873401390735027e-5 -8.779573589449815e-5 0.0001975228510808377 0.00016496205309053817 -0.0001036921389506126 5.2646613325894904e-5 -2.6647458512480558e-5 7.878534917367153e-5; -2.2273181541881515e-5 -7.406207086460312e-5 -7.625087536581302e-5 -0.00012175828678010224 6.53662972646949e-5 -7.027433829789599e-5 1.9352391888938716e-5 9.248686902280738e-5 5.646224227093567e-5 1.1191949767810372e-5 5.203190076572038e-5 -0.0001208733193314701 0.00015171709301938063 6.888978887873092e-5 0.00015480642096663303 0.00011571315909095368 0.00011205597903171821 0.00014013579121379993 -1.6475163731958246e-5 0.00011483469634838362 -0.000134964906549437 0.00015307314234377928 6.351822585856916e-5 -2.0374011085372287e-5 -4.268376150576026e-5 -5.204116944516967e-6 4.381444151221906e-5 7.444258872710408e-5 -4.79222836409903e-5 7.502354483876551e-5 -1.269323737942656e-5 4.218499186372574e-6; 0.00013782047105488364 -5.0240787031308695e-5 0.00014111132847615572 6.800188791244535e-5 0.000137177101778848 1.812208799534228e-5 2.5005740284845842e-6 0.00013493533558215723 0.000119325583428198 -2.3485527249671743e-6 9.873478345112447e-5 -2.9897093126752016e-5 -0.0002102884805295202 -3.420571048701738e-8 0.0001223095118534356 5.795419426290904e-5 4.027117253868501e-5 5.047305464398146e-5 -1.5872213707136134e-5 1.479334011532001e-5 -1.4895247231421449e-5 -0.0003092392339836293 0.00016364953850866974 6.294782585469544e-5 2.663227175658774e-5 7.335180244315283e-6 -6.668516064523228e-5 -4.583409602632175e-5 3.660916488154036e-5 -3.391819839593434e-5 -5.7376769772909925e-5 -1.3255667268419578e-6; -9.780057715773204e-5 -6.0414786083258356e-5 -0.0001276545588444975 -0.00012286753170714395 
0.00010426266836644441 0.00012344674197434656 -2.9423508556770727e-5 -0.0001235760790115249 -2.423644199383701e-5 -0.0002547099209442795 1.6557129321752835e-5 0.00012967523817848244 6.352364308610137e-5 -4.632411199794857e-5 -2.621862021273258e-5 3.479529206002369e-5 3.088455215389145e-5 0.0001049186905328175 0.00012888349757473206 3.1005365792100256e-5 2.9419570830085753e-5 -2.36311368900185e-5 -4.1387061890547803e-5 5.417114172074279e-5 9.482726474003151e-5 1.3374630918277047e-5 -2.9676926522499126e-5 7.970011230268759e-5 -9.134454720709372e-5 -0.00013081744672321469 2.846145172823188e-6 3.8242866332312844e-5; -2.2937661824363074e-6 0.00014101101638606557 -2.934931384425539e-5 -2.296998703878384e-5 -0.00011301606340476377 -0.00019485893048023074 4.8799444049111024e-5 8.71141784600693e-5 -6.951249437316527e-5 9.315400180730039e-5 3.4608634597127036e-5 0.00011603034741801964 -4.101563571569655e-6 -7.462705601847908e-5 0.0002924396868920585 -0.00022787156543837086 8.828007790815856e-5 -2.4440714550107306e-5 -0.00022373938903820883 0.00015438709320824212 0.00010572557599606914 6.029009106077144e-5 -0.00021323305176248038 7.561759830351508e-7 -0.00014027538039864236 -6.800349714377114e-5 0.0001483088309769228 -5.097789820775541e-5 -7.704896764971033e-5 -2.739257055419161e-5 -4.8338486375550055e-5 -2.0463825275330896e-5; 6.806003143347987e-5 0.00011450164670158647 -0.00010563821347396133 -0.00010747391575214767 -0.0001879078663952032 6.702229797874106e-5 7.656930024304782e-5 -1.4734571825191801e-5 -0.0001327259270075104 -0.00010151580868487551 -5.3501164122376925e-5 -4.177420625726418e-6 8.476984289509202e-5 1.0882981721710066e-5 1.6527927025793934e-5 -0.00011037912559195857 -2.022527046584032e-5 9.288738869863997e-5 2.2295416736500444e-5 1.0436444379502623e-5 5.651234091410101e-5 -0.00010572946125840234 0.00018916237963630353 0.00011986738895565807 -7.772614849464503e-5 -8.996168258100248e-5 -7.742370876449563e-5 -0.00014895790450473746 -0.00010976004346238665 
2.166376720915135e-5 0.00019009752936462396 8.812104710877873e-5; -1.0573482536711172e-5 7.251994430116394e-5 8.895741666928906e-5 -1.528065694874583e-5 -0.0002607170881630757 0.00022608680774719472 -0.00023659698308183701 -0.00023000278269610282 -6.485903395078128e-5 0.0001144358411736883 -0.00012215925198252287 -4.129449188913111e-5 1.3903741096144729e-5 -4.734815774572994e-5 3.589913649629818e-5 3.9392818511985456e-5 5.081173913349419e-5 -2.0453888278281055e-5 0.00011141428149567004 -0.0001116061448512497 0.00012415248966571906 -4.045085187972497e-5 -4.005600384989736e-5 8.679356870270359e-5 0.00010031354213860835 7.482635010528399e-5 -2.1860498232128783e-5 -4.3694073606459204e-5 -0.00016589285041409394 0.00020271797439345182 -2.5046370860948065e-5 0.00013079519158491; 0.00022860239045795447 -5.6200221983214705e-5 -1.4618959134286735e-5 -7.699923076336972e-5 -0.0002699280062680681 -4.776906785458524e-5 2.8558161326043294e-5 3.0269595705142415e-5 -2.363032357732102e-6 -0.00023705041308324522 3.747564046158433e-5 9.779497708758455e-5 -5.6110400286456206e-5 -6.522791732471588e-5 0.00014992254222839855 6.845469983624191e-5 7.600863541107514e-5 -0.00011969018446880925 -7.562617752604341e-5 0.00016417435330902806 0.00016028676550001143 -0.00010071618958512313 5.776777329252842e-5 -4.5909926078943434e-5 -0.00010004829305593111 0.00013893839281614782 -4.0980130101284107e-5 -0.00026432898226097317 1.835614279652717e-5 -0.00023622457734212313 -0.00012384797951575595 7.123405743721093e-5; -0.00013993096319994255 -7.270275193539768e-5 -0.00010503211626589903 -0.00020517352925187787 0.00019563841385726808 0.0001020845001000597 0.000165269105196551 -4.666748748082919e-5 -0.00017078528719528945 0.00012124485408614539 -2.0630212696069363e-6 0.0001090041104141546 5.1684200003119096e-5 -6.715746280209342e-6 -4.382851432117998e-5 -9.64027286720766e-5 -2.689747736829585e-5 -0.00022077233836558178 -0.00010975680299842098 -7.786228627470669e-5 1.3134411233951175e-5 
3.4779516568704194e-5 -8.170244211604226e-5 5.205154935282151e-6 -1.8575506259625412e-5 0.0001637593730954151 -4.663318861665346e-5 -3.94150231321222e-6 -3.389854202782358e-5 -0.00013647029947995383 5.4765508827727165e-6 2.642133483798827e-5; -5.2434550690503574e-5 -0.00011100297177239447 5.535775364347858e-5 -7.715374720777338e-5 -5.202171285549184e-5 0.00020800401927641455 9.317154564866681e-5 0.0001670532126978009 9.244896302203503e-5 4.093624904140886e-5 -7.875430508778494e-5 -8.092657133732134e-5 -6.213133218028224e-5 -8.302515396411647e-5 3.621184243325931e-5 -4.350114233091982e-6 -2.5292985975682804e-5 3.878105582641414e-5 -7.315243618126145e-5 -9.228941251552074e-5 -0.00015430048568435898 -4.9420703527574e-5 -0.0001105819848648379 8.487271208519998e-5 0.00014192596956360022 4.585153675958425e-5 8.392335241165884e-5 -9.405841575643028e-6 0.00010346142318537761 -0.00013767076068819568 -0.00010506186861806698 -9.110469744103553e-5; -0.00013296077800190105 0.00012167176332583163 -0.00011998149552822535 -0.000175700538831041 0.00017074189831761186 -5.7792262407371355e-6 -0.0001183994622362989 5.792094320783094e-5 7.707914351045595e-5 0.00011230823160873284 0.00012931390533482183 -2.461368008083455e-5 1.530573585911289e-5 1.2435955385409574e-5 -6.642942885377004e-5 7.13358854114763e-6 6.134947629903893e-5 -6.766826250098874e-5 -2.2894009317713974e-5 -3.488000352375517e-5 -7.770674842616987e-5 0.00010025659735107325 -3.114558189848926e-5 -7.13741842002841e-5 -3.084366603728871e-5 -2.3316989837656845e-5 5.27557372595416e-5 1.4344914733338981e-5 0.000122945681640672 -5.998198817109231e-5 -8.078295054587096e-5 2.3206219838141663e-5; -0.00011633543503675287 -0.00010606519530967889 0.00012713684232465355 6.750840516786313e-5 7.073441197888126e-5 -0.00023916044624031616 0.00013134356818654287 -9.77328123056485e-5 -9.953008297572952e-5 3.818940281660538e-5 -2.288702084599496e-5 -6.578732022290457e-5 -2.6991150248536597e-5 -0.00012622987155311803 -1.0962298388517167e-5 
1.2589449811163081e-5 -0.00014802516265747167 -0.00011233051513527647 2.5269110169279666e-5 9.11328501885595e-5 6.473491926470065e-5 -0.00019960836975344403 0.0001650269445732879 0.00010152290310968702 1.1047350923722564e-6 3.1967927911570146e-5 4.9140752211930656e-5 -5.8378372655464507e-5 -8.62138509049129e-5 4.49370895282088e-5 -3.8863377581333066e-5 8.625613959764988e-5; -4.4958016328259124e-5 2.3045106250100115e-5 0.00011451334129134996 3.59269563753697e-5 0.0001650782126355985 -0.00012255655903089214 -6.723305133589068e-5 -0.00010632935456185293 -7.435464950929756e-5 5.02170098892022e-5 2.6974594480015246e-5 -8.702121860066911e-5 -5.419801683699636e-5 -4.7502902557136125e-5 -3.564384470763044e-5 -9.222205127542228e-5 -1.0649986247409552e-5 -0.00012119543108808254 9.329423327046107e-5 -6.742740671568274e-5 -0.00013777693103718766 9.212860303279663e-5 -4.242522637894727e-5 6.820242404984962e-5 -8.587971043940354e-7 4.5853789092979065e-5 -2.001277581683178e-5 2.0583678893052634e-5 0.00011201117584762999 -9.964180363993458e-5 4.406275391858639e-5 0.0001154959666431055; 3.328329974504672e-5 2.7134490941791353e-6 0.00010287144239594347 -1.2391066702158562e-5 7.685280325339154e-6 4.903574863809069e-6 -4.275174359981423e-6 -5.5969417116620993e-5 -4.024632373291009e-5 -5.06496443627578e-5 3.7872104331134666e-5 -6.278320954112996e-5 -0.00016820516331751657 -0.00015453008829315127 3.13536004835275e-6 -9.944129571483402e-5 -4.3236600431165905e-5 -2.791600938468869e-5 3.945387932656397e-5 4.87188290939503e-5 -7.98725333840716e-6 0.00016683872952892527 0.00012160265312598147 9.059478459714991e-5 0.00011591599748537936 -1.4767414467834617e-5 -2.585727899674713e-5 0.00015263658413498736 -0.00019477397738189056 -0.0001185400441677954 4.743865709354227e-5 4.841525795239211e-5; 0.0001531604836372051 -1.5832590475080652e-5 -1.5593553439573156e-5 -0.00018695136107013254 0.0001206315834450269 0.00014692028769316758 -0.00019783537875367148 0.00023188270704382557 
-4.9151603275429216e-5 0.00011536243680424402 -7.037446932842092e-5 -2.2245403056451212e-5 -0.00010129833692658234 -0.0001133488215829551 5.6942459876451084e-6 -7.616581144717708e-5 -5.5644391363935655e-5 -0.00014104338636127732 -3.1862676519418896e-5 -0.00011037259839680652 -0.0001089529644788491 -8.474092767315942e-5 -3.3011321034034154e-5 5.5173242570492845e-6 -1.7034598681827386e-5 -0.00014570763898988416 -0.00010009108821137856 -0.00013691582291827702 -0.00019268418993771314 0.00010581948724833224 3.897527919874158e-5 0.00020689286309321105], bias = [-3.545083550533263e-9; 2.31262532535305e-9; 1.0073268774503124e-9; 2.13936591062046e-9; -1.3923731694890866e-9; 1.3089062980578348e-9; -2.4166139082769433e-9; 1.6040457519010004e-9; -8.561387093406967e-10; 5.490763451835097e-9; -2.294233418313196e-9; -2.179015494155059e-10; 4.291891525808786e-9; -2.6086586474300424e-9; 1.8367511601320225e-9; 4.1491143257914804e-10; -3.897686151669041e-9; 9.454332476492823e-10; 3.7142919605133392e-9; 2.847146296538487e-9; -7.055752051867177e-10; -1.2185764799578876e-9; -1.0223194353005597e-9; -1.1459196901874318e-10; -2.7553634368178576e-9; -3.0205449406682574e-9; -1.1734131831750888e-9; -2.9235819847055696e-10; -2.0623204589666545e-9; -1.0129176023484006e-9; -2.736847562511537e-10; -4.0848364438463534e-9;;]), layer_4 = (weight = [-0.0007225265040813575 -0.0005472821399113828 -0.0008314028004748516 -0.0005487975886230616 -0.0007807571538695922 -0.0007071450532340547 -0.0006756196531636345 -0.0006106549222079432 -0.0007109791910256437 -0.0007364006234637578 -0.0007361333063461191 -0.0005701093336886052 -0.0008812051265632083 -0.0005713179349483784 -0.0006474175137200081 -0.0006579283065006515 -0.0005623261160788136 -0.0007128806804228791 -0.0007243905771776643 -0.0005761841760295906 -0.0005888091340428618 -0.0006877250880482119 -0.0006308852091881862 -0.0007599775218331967 -0.0007794552291696068 -0.0006937656585647561 -0.0008325764717962617 -0.0009430102514923222 
-0.0007101651724410151 -0.0008587315225691966 -0.0005207161312058312 -0.0006479826333197058; 0.00033435599212075165 9.107746943896243e-5 0.000306954140781786 9.348540765200025e-5 0.0002686219684227278 0.00016250558781924414 0.0003982923043493587 0.000274089535063 0.00023347738518830966 0.0003183739465348586 0.0003573822456892075 0.00031734907380479645 0.00032919391222671585 0.0002187956030389761 0.00025473549820363196 4.080474539743468e-5 6.362002426227138e-5 0.0004064085004238355 0.00017329806295601923 0.0002082355806429911 0.0001918616055975271 0.0001380783039460849 0.00022751448088951357 0.0001299608019473249 0.0003342862262763715 0.0002368188086447873 0.000166067999956632 0.0002382672486885848 0.00024248116185109078 0.00024905101863435727 0.00022339130220890545 0.00024662126688168833], bias = [-0.0007259725265166863; 0.00021988771638519657;;]))

Visualizing the Results

Let us now plot the loss over time.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally, let us visualize the results.

julia
# Re-solve the neural ODE with the optimized parameters `res.u`, sampling the solution at
# `tsteps` with fixed-step RK4 (`adaptive=false`, step size `dt`).
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(
    prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))
# Keep only the first element of what `compute_waveform` returns (presumably the waveform
# itself; any remaining outputs are unused here).
waveform_nn_trained, = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Each series is drawn as a translucent line overlaid with semi-transparent markers.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Group each line with its markers so every legend entry shows both glyphs.
    entries = [[l1, s1], [l2, s2], [l3, s3]]
    labels = ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"]
    axislegend(ax, entries, labels; position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Print GPU toolchain details only for backends that are loaded and usable on this machine.
if @isdefined(LuxCUDA)
    if CUDA.functional()
        println()
        CUDA.versioninfo()
    end
end

if @isdefined(LuxAMDGPU)
    if LuxAMDGPU.functional()
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.4, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.3.4
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: Quadro RTX 5000 (sm_75, 15.729 GiB / 16.000 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.