Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$ (e.g. Theoretical Mechanics of Particles and Continua, Section 4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of two point masses
`m₁` and `m₂` about their common centre of mass.

With `M = m₁ + m₂`, returns `(r₁, r₂)` where `r₁ = (m₂ / M) .* path` and
`r₂ = (-m₁ / M) .* path`, so that `r₁ - r₂ == path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution in orbital-element form into Cartesian (x, y) positions.
# `soln` holds one column per time sample, and two row layouts are accepted:
#   * 2 rows (χ, ϕ): the constants (p, M, e) come from `model_params`, which must have
#     length 3 (M is unpacked but not used here).
#   * 4 rows (χ, ϕ, p, e): p and e vary along the trajectory; `model_params` is ignored.
# The radius follows r = p / (1 + e cos χ), and (x, y) = r (cos ϕ, sin ϕ).
# Returns a 2 × size(soln, 2) matrix whose rows are x and y.
# Throws an AssertionError for any other row count, or when `model_params` has the
# wrong length in the 2-row case.
@views function soln2orbit(soln, model_params=nothing)
    # `∈` must be an explicit binary operator here: without it the macro receives the
    # row count itself as the condition and fails with a non-Bool TypeError.
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    # The angles occupy the first two rows in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        @assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v, dt)

First derivative of the uniformly sampled series `v` (sample spacing `dt`).

Interior points use second-order central differences. The two endpoints use second-order
one-sided stencils, so the result has the same length as `v`. `v` needs at least 3
samples.
"""
function d_dt(v::AbstractVector, dt)
    # One-sided second-order stencils at the boundaries.
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]

    # Central differences (v[i+1] - v[i-1]) / 2 for every interior index.
    interior = (v[3:end] .- v[1:(end - 2)]) / 2

    return vcat(left, interior, right) / dt
end
d_dt (generic function with 1 method)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v, dt)

Second derivative of the uniformly sampled series `v` (sample spacing `dt`).

Interior points use the three-point central stencil `v[i-1] - 2v[i] + v[i+1]`. The
endpoints use four-point one-sided stencils, so the output has the same length as `v`.
`v` needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    # Four-point one-sided stencils at the boundaries.
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]

    # Central second differences at interior indices.
    body = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]

    return vcat(head, body, tail) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

One component of the trace-free moment tensor of a planar orbit, scaled by `mass`.

`orbit` is a matrix whose first row is `x` and second row is `y`. The components are
`(1, 1)` → `x² - (x² + y²)/3` and `(2, 2)` → `y² - (x² + y²)/3`. Any other `component`
gives the off-diagonal `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    xx = xs .^ 2
    yy = ys .^ 2
    xy = xs .* ys
    tr = xx .+ yy

    i, j = component[1], component[2]
    moment = if i == 1 && j == 1
        xx .- tr ./ 3
    elseif i == 2 && j == 2
        yy .- tr ./ 3
    else
        xy
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative of one trace-free moment-tensor component of `orbit`.
The component comes from `orbit2tensor`, the derivative from the finite-difference
`d2_dt2` with sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole terms for a single body on `orbit`.
Each term is computed by `h_22_quadrupole_components` for the `(1, 1)`, `(1, 2)` and
`(2, 2)` components.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hij(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return hij((1, 1)), hij((1, 2)), hij((2, 2))
end

"""
    h_22_strain_one_body(dt, orbit)

(2,2)-mode strain of a single (test-particle) orbit sampled with spacing `dt`.

With `(h11, h12, h22)` from `h_22_quadrupole`, this forms `h₊ = h11 - h22` and
`hₓ = 2 h12`. It returns `(√(π/5) h₊, -√(π/5) hₓ)`, with the constants evaluated in the
type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalisation of the (2,2) mode is √(π/5), not π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum of the mass-weighted quadrupole terms `(h11, h12, h22)` of two bodies. Each body's
terms come from `h_22_quadrupole` with its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q1 = h_22_quadrupole(dt, orbit1, mass1)
    q2 = h_22_quadrupole(dt, orbit2, mass2)
    return q1[1] + q2[1], q1[2] + q2[2], q1[3] + q2[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

(2,2)-mode strain from the orbits of two bodies with masses `mass1` and `mass2`.

The combined quadrupole terms `(h11, h12, h22)` come from `h_22_quadrupole_two_body`.
From them the function forms `h₊ = h11 - h22` and `hₓ = 2 h12`, and returns
`(√(π/5) h₊, -√(π/5) hₓ)`. It throws an AssertionError unless the masses sum to one
(within 1e-12).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalisation of the (2,2) mode is √(π/5), not π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Convert an orbital ODE solution `soln` into the (2,2)-mode strain `(h₊, hₓ)` pair.

The orbit is first mapped to Cartesian coordinates with `soln2orbit`. For
`mass_ratio == 0` (test-particle limit), the one-body strain is used. Otherwise the
relative orbit is split between masses `m₁ = q/(1+q)` and `m₂ = 1/(1+q)` (they sum to
one) and the two-body strain is used. Throws an AssertionError unless
`0 ≤ mass_ratio ≤ 1`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        # Pass the locally computed orbits and masses (previously referenced the
        # undefined names `orbit1`, `mass1`, `orbit2`, `mass2`).
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the orbit ODE for the state `u = [χ, ϕ]` with constants `p`, `M`, `e`.
Returns `[χ̇, ϕ̇]`. Only `χ` enters the rates; `ϕ` and `t` are unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    ecχ = e * cos(χ)

    numer = (p - 2 - 2 * ecχ) * (1 + ecχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = numer * sqrt(p - 6 - 2 * ecχ) / (M * p^2 * denom)
    dϕ = numer / (M * p^(3 / 2) * denom)

    return [dχ, dϕ]
end

mass_ratio = 0.0         # test particle (0 selects the one-body branch of compute_waveform)
u0 = Float64[π, 0.0]     # initial conditions (χ, ϕ)
datasize = 250           # number of saved samples
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform (Float32 endpoints)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing of saved samples; used by the finite differences
dt = 100.0               # fixed integrator step (solves below use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic orbit with fixed-step RK4 (step `dt`), saving at `tsteps`.
# `first` keeps only the first element of the (h₊, hₓ) pair returned by compute_waveform.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the ground-truth waveform as a line with circle markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    # Makie's marker keyword is `marker`; `markershape`/`markeralpha` are Plots.jl
    # keywords that Makie does not accept. Marker transparency is set via `alpha`.
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input ($\chi$, the first state variable) and has two outputs. We'll make a function `ODE_model` that takes the state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally not recommended, but in case you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and instead use a truncated normal initializer with a small standard deviation from WeightInitializers.jl:

julia
# Network: apply cos element-wise to the input, then two 32-wide cos-activated Dense
# layers and a linear 2-output layer. All weights are drawn from a truncated normal with
# std 1e-4, so the initial outputs are close to zero; ODE_model adds 1 to them, making
# the untrained correction factor ≈ 1.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# NOTE(review): Xoshiro() is unseeded, so initial parameters differ between runs; pass a
# seed (e.g. Xoshiro(0)) for reproducibility.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-3.7090495f-6; -0.000112906906; 9.121012f-5; 0.00013011461; -0.00019051759; 9.791001f-5; -0.00017315523; 8.074661f-5; -8.143839f-5; 4.9919767f-5; -0.00011104929; -3.756932f-5; -8.545672f-5; 4.145344f-5; 2.53183f-5; -3.3249813f-5; -3.2966374f-5; -4.547579f-6; -0.00010109941; -0.00010161676; -4.226429f-7; 6.6335733f-6; -0.00012566015; -8.305307f-5; -0.00019378384; 0.00016258833; 0.00012272003; 0.00011263117; 0.0001146841; 0.00014358909; 0.00011572713; -8.055819f-5;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[-8.681742f-5 1.1953795f-5 0.00010715996 -8.308292f-5 -0.00011259235 -9.309397f-5 6.2254054f-5 -6.769662f-5 -0.00016840523 -5.4019565f-5 2.6072155f-5 -2.8465623f-5 0.00010106866 -3.0687886f-5 -0.00011711822 -0.00010991415 8.987795f-5 -8.041053f-5 1.3697735f-6 -0.00013814097 -0.000196295 -8.226939f-5 -0.00013082367 -0.00020371508 -9.398931f-5 -8.942611f-5 0.00017659839 -9.3502995f-6 -0.00010591908 -1.1122162f-5 -4.9817f-5 0.00011160792; 0.00010110031 -3.2140644f-5 -0.000118786484 7.598621f-5 -0.00021104183 -7.268992f-5 -1.5337496f-6 1.543101f-5 0.00016243887 0.00022392622 0.000112157235 0.00020663635 1.5846615f-5 1.3199119f-6 0.00010076237 4.9832848f-5 -0.00019492928 -5.414747f-5 -0.00011022435 5.1728002f-5 7.402235f-5 0.00018896912 -0.00016936212 -1.8287565f-5 4.624603f-5 -6.0092123f-5 6.3500695f-5 4.849502f-5 3.361846f-5 0.00028407422 0.0001144317 0.00012609284; -5.7808968f-5 -0.00013884887 -0.00015048636 0.0001832897 0.00016434606 -0.00010018818 -6.262455f-5 7.674989f-6 -0.00010427912 0.0002173652 8.167427f-5 0.00014249625 0.00020902247 8.791538f-5 7.20736f-5 1.4174836f-6 -0.00012735318 -0.000147696 -0.00015657992 5.979992f-5 5.0760133f-5 2.7759743f-5 -0.00013477371 -0.00011382677 -5.048339f-5 1.6777867f-5 2.7351985f-5 -8.3982064f-5 
-1.19252045f-5 0.0002690386 9.2352464f-5 -2.7454584f-5; 1.4755375f-5 -8.426129f-5 -2.7090446f-6 3.696872f-5 6.821808f-5 -6.851097f-7 0.0001085061 9.13431f-5 0.00016152454 0.00013308304 -2.1248588f-5 -1.4764769f-5 0.00013471022 -6.522942f-5 8.0290585f-5 -0.00014950252 -6.5920416f-5 2.7519005f-5 -5.3942098f-5 -0.00015586051 -4.3518772f-5 -1.6987548f-7 -6.4808655f-6 4.1132575f-5 -6.60728f-5 3.4447003f-5 0.00013884316 -0.00011263939 2.0685926f-5 -4.0382f-6 -3.1164487f-5 2.0008927f-5; -0.00034580054 -5.565761f-6 -4.3846776f-5 -9.2424f-5 -4.9403137f-5 6.9937865f-5 -6.0118557f-5 -1.3859264f-5 6.929537f-5 3.483419f-6 3.1559375f-5 -3.591251f-5 4.5905344f-5 9.4376715f-5 0.00016551877 -4.651011f-5 -0.00023693174 0.00011365839 5.2485808f-5 -1.9568239f-5 -6.763457f-5 1.7317467f-5 0.00012141837 -8.476012f-5 -7.7501645f-5 3.4765002f-5 -8.433988f-5 1.454425f-5 -8.236019f-5 -3.986479f-5 -2.802443f-5 -6.839656f-5; 0.00010394678 0.00016244932 6.633131f-5 5.211691f-6 8.8802626f-5 0.00017758812 0.00025375208 -0.00010280256 6.1079096f-5 3.950997f-6 -4.698021f-5 -4.7932583f-5 -0.00021001561 1.8178158f-5 0.00017734186 0.00011339601 4.3998116f-5 5.064963f-5 1.44430605f-5 -5.7986148f-5 0.00015236158 9.695923f-5 -0.00018106651 -6.2658546f-5 5.753683f-7 -2.7285794f-6 0.00013374163 -1.579728f-5 1.2369233f-5 0.000111194495 1.7850418f-5 -3.4851386f-5; -0.00021086953 0.00021169739 8.280408f-5 0.00013797052 -8.510385f-5 -8.809881f-5 -7.043415f-5 -0.00011623214 1.2136932f-5 8.028604f-6 7.5356393f-6 0.00016929056 -0.000105428626 -0.00016996157 8.366126f-5 5.777699f-5 -0.00014327392 8.025861f-6 1.0570719f-6 7.147398f-5 -0.00025730202 5.7745976f-5 1.1637128f-5 7.6564735f-5 -1.5447838f-5 0.00021524967 5.3557025f-5 0.00022177152 -6.127087f-5 3.4555353f-5 0.00014511835 4.897324f-5; 8.175289f-5 0.00015657517 0.00015745153 4.652048f-5 0.00024794586 4.9673446f-7 -4.8457197f-5 -4.852245f-5 -6.74543f-5 -8.7658336f-5 -9.0495676f-5 2.8762963f-5 -2.6068014f-5 5.4468397f-5 1.6806354f-6 -0.00012612424 9.83133f-5 
-6.639149f-5 0.00013596735 9.7779484f-5 -0.00019246044 0.00016396948 3.504468f-5 1.9486844f-5 -1.454339f-5 0.00013316316 -0.00024422866 8.7519125f-5 -0.0001335312 7.5087746f-6 6.136154f-5 -6.822704f-5; -2.545089f-5 0.00010519635 0.00014386929 0.000117897085 -0.0001594567 -4.61676f-5 -0.00020111025 7.873897f-5 -0.00014260178 -3.8196642f-5 -0.0001102902 1.0730551f-5 8.299172f-5 7.0368915f-5 1.7818236f-5 0.00012273728 -8.001158f-5 -1.07865135f-5 -5.6473702f-5 -0.00010031742 5.2898296f-5 1.0885642f-5 -1.01212545f-5 -8.57822f-5 -8.715929f-5 -0.00010163351 2.4197041f-5 1.5008553f-5 -2.3865668f-5 4.24447f-5 -8.1361686f-5 0.00012188513; -1.0853375f-5 -0.00012917901 -8.178341f-5 -4.9950468f-5 -1.6078196f-6 -2.4463234f-5 -6.48672f-5 0.00011854046 -0.00016701779 0.00010656745 -0.00012960251 1.8498691f-5 -0.00021187244 7.2603405f-5 -0.00014740867 0.00016342575 -0.00014089463 3.9493883f-5 -0.00014319208 -0.00018864895 0.00014325175 -7.6922785f-5 3.0018638f-5 -0.00012054979 0.00012688208 -4.0962932f-5 0.00016227544 0.00019549277 4.041729f-5 -2.4962645f-5 -0.00012868913 1.3790525f-6; -7.8374804f-5 -0.000112382295 6.4980086f-5 4.2328047f-6 0.00013774814 3.926343f-5 -0.00012331034 -0.00021519381 -2.101757f-5 4.3092474f-5 -5.1893043f-5 -0.00010654842 6.310718f-5 -8.770099f-5 3.2946453f-5 -3.5490037f-5 0.0001356569 6.2137115f-6 -0.00016339564 -9.515137f-5 0.00013962593 1.3231626f-5 -0.00017753059 -4.927631f-5 0.00018200887 -0.000111540074 -7.816596f-5 6.031054f-5 -4.026133f-5 -0.00013119551 -0.00015108012 -2.905789f-5; -6.889106f-5 -5.2823492f-5 9.792333f-5 -0.00012980057 -7.9324134f-5 -0.00014629641 -7.694081f-5 6.681757f-5 0.00016201356 -4.8013313f-5 0.00025182497 4.0265688f-5 2.8878383f-5 3.85107f-5 2.9499244f-5 -0.0001385601 0.000100752746 5.7136745f-5 6.101586f-5 -3.775728f-5 -0.00022723804 -6.0688162f-5 6.9155045f-5 9.997962f-5 -0.00019921419 0.0001240463 9.358553f-6 -2.7579745f-5 0.00015752655 0.00017333394 -7.462778f-6 1.38037385f-5; -2.811281f-5 -0.00015431283 4.720491f-5 
-0.00010623294 -7.0089576f-5 -8.7133216f-5 -6.986576f-5 -0.00011344189 -2.1413478f-5 0.00013285731 -7.601069f-5 -4.185225f-5 -0.00022417292 -5.0983053f-6 0.00011157384 -5.547629f-5 -8.196038f-5 -2.9131677f-5 5.099293f-5 2.1994985f-5 4.628682f-5 4.6778598f-5 -0.00018167713 -0.00026135717 -0.00016970113 9.1180955f-5 -6.671423f-5 4.4532648f-5 -8.7167355f-5 -2.476722f-5 0.00017731458 -2.77152f-6; 4.1009276f-5 0.00023882791 7.93112f-5 3.110778f-5 -4.124147f-5 -0.000102587706 -6.890517f-5 2.6575883f-5 5.8320813f-5 7.297792f-5 -8.027246f-5 5.9410275f-5 -0.0001820034 0.00010146975 -3.2810796f-5 0.00014192681 0.00012037565 4.4699605f-5 2.8657185f-5 3.0011559f-5 -0.00010762064 7.0687536f-5 -0.00024256537 4.053539f-5 -0.0002664685 0.00015984717 -0.0001059718 -8.260303f-5 4.465097f-5 0.00012463093 -9.8051576f-5 0.000111182744; 6.151414f-5 5.713984f-5 0.00019483129 -0.00010608005 -6.9251386f-5 0.00015613083 -7.374048f-5 -0.00012754311 8.801124f-6 0.00010866489 7.697642f-5 -1.1426925f-5 -5.065199f-6 -0.000109447545 5.856711f-5 6.5505905f-5 1.8553865f-5 1.8521362f-6 0.00012014684 5.9555427f-5 7.45781f-5 5.6222416f-5 -0.00010294075 3.5295325f-5 0.0002509793 0.00010666124 -9.0184f-6 -0.00011005037 9.56684f-5 -3.0360077f-5 7.978498f-5 -1.6408929f-5; 8.194295f-5 -2.8632496f-5 -3.6093284f-6 0.00012564301 -2.9386609f-5 -9.8870405f-6 9.03286f-5 -2.1390899f-5 -3.924119f-5 -8.810224f-5 7.509734f-5 2.8049697f-5 2.6746224f-5 -8.651917f-5 1.2966826f-5 -8.789904f-5 -1.4152408f-5 0.00012537758 0.000134672 0.000104479615 2.431241f-5 8.5998385f-5 0.00017886587 -0.00013944099 0.00011804696 1.3809546f-5 -0.00021141955 3.475662f-5 -3.622948f-5 -2.241332f-5 3.0626703f-5 2.5467018f-5; 7.216489f-5 -0.00012900318 5.694839f-5 5.641296f-5 -8.478518f-6 4.7448077f-5 4.0836974f-5 -4.4486922f-5 2.440442f-6 -7.443454f-5 -0.00014937662 -0.00014152724 7.8894245f-6 -6.7883184f-6 -0.00012269773 5.456745f-5 -7.2862706f-5 0.00011986866 0.00017193505 6.26708f-5 -9.644269f-5 -8.7138935f-5 2.522146f-5 5.119194f-5 
-4.4324537f-5 4.028435f-5 8.364928f-5 8.941169f-5 -4.68505f-5 -1.8657873f-5 3.0292138f-5 -0.00014705997; 5.314895f-6 0.00010540193 -0.00015732409 6.241365f-5 1.409876f-5 0.0002622695 -7.63117f-7 -6.9578764f-6 -0.00011278468 -6.873307f-5 -4.1518084f-5 1.1686574f-5 0.00010328293 2.4198856f-5 7.212031f-5 -6.388013f-5 -2.2708893f-5 -5.9628856f-5 -8.752324f-5 0.00020093065 0.00013446866 -7.019735f-5 -7.428874f-5 6.0525366f-5 8.812658f-5 -0.00018491282 4.468222f-5 0.00022693821 -7.845646f-5 0.00020721364 7.855074f-6 4.5476852f-5; -2.3371049f-5 2.8238948f-5 8.44367f-5 -0.00022490088 2.7435406f-6 3.4001972f-5 0.00013888112 5.0091003f-6 -5.5446995f-5 -0.0001255068 0.00011159734 -7.512412f-5 0.00010185609 -0.0001739578 -0.00012948002 -6.0169903f-5 5.4865445f-6 -6.0948347f-5 5.964101f-5 -3.400834f-5 1.22018955f-5 -7.90831f-5 -4.5078592f-5 8.376678f-5 8.0198886f-5 -1.4101281f-5 0.00016897787 8.0497215f-5 1.750198f-5 -4.2321346f-5 -4.2371354f-5 0.0001131803; 0.00025280452 -2.714492f-5 -1.26979f-5 -6.978491f-5 2.2493057f-5 8.280012f-5 -0.00014780223 -1.9762192f-5 -2.500431f-5 -4.958074f-5 -0.000101497026 5.0278926f-5 2.0691452f-5 9.941359f-5 2.1321977f-5 0.00021658294 -0.00011038208 5.714876f-5 3.0194138f-5 2.9620065f-5 -0.000104757746 -2.4041596f-5 -0.00020974783 0.00013176004 -1.5962261f-5 -0.000122836 -2.9096169f-5 1.2560204f-6 7.78512f-5 3.1583277f-5 -3.903629f-5 1.32454015f-5; 4.0382045f-5 -6.687556f-5 0.00023119649 -3.6016053f-7 -2.3728493f-5 -0.00010336594 -5.8964666f-5 0.00016770007 -0.00012970397 0.00012703006 8.2029685f-5 3.875446f-5 8.416053f-6 9.104297f-5 0.00018846634 -6.474173f-5 3.9056915f-5 2.678328f-5 -9.05898f-6 -0.00013703665 -2.8689876f-5 9.170536f-5 9.643583f-5 0.0001013675 -0.00018670266 0.0001281381 5.723274f-5 -6.985026f-5 5.042918f-5 0.00017201457 -6.945722f-5 -0.00019873008; 0.00011840272 -0.00018281586 -8.085338f-6 -0.000114027614 -1.9691597f-5 -0.00014649521 0.00012466838 8.757997f-5 0.00030809356 -6.51461f-5 0.00012747473 -9.889161f-5 -8.19778f-5 
-0.00018020373 -9.971686f-5 -4.988758f-5 0.00010313153 -4.150747f-5 0.00014201977 5.0489212f-5 -2.7597096f-5 0.00012924556 0.00012679161 -3.466401f-5 0.00017229588 -7.526101f-5 1.4235609f-5 5.0562205f-5 9.156624f-5 -0.00013004929 0.00013618059 -7.448611f-5; -3.9340407f-6 2.4759824f-5 0.00020226564 0.00016160692 6.505699f-5 5.2484127f-5 -7.416066f-5 2.043867f-5 -3.7390412f-6 5.9911818f-5 0.000104783074 5.811357f-5 6.024747f-5 7.86578f-5 5.3223775f-5 0.000108952445 3.8660244f-5 9.667801f-5 0.00022622841 -0.00010545259 0.00013061814 -8.105311f-5 6.9109246f-6 -5.793272f-5 0.00024415914 0.00017145094 -9.873154f-5 -2.8018167f-5 -7.812979f-5 -0.00011074599 -3.9501597f-6 0.00012168358; 0.00024545882 -0.00012564147 0.00014229477 -0.00021461344 -0.00012228415 -6.9822614f-5 -8.555772f-6 -0.00011408073 0.00012750144 -3.6505447f-5 7.7840596f-5 -1.3515818f-5 -9.1582784f-5 3.5173027f-5 4.8769547f-5 0.000109290006 -2.3202f-5 -6.5331325f-5 -3.487806f-5 3.334526f-5 -4.559674f-5 9.817268f-5 -0.00025642876 9.271681f-5 -0.00027783483 -0.00016362418 3.1208296f-5 -7.096567f-5 -6.415798f-6 -1.9855715f-5 0.00024786548 8.744505f-6; 7.211605f-5 0.00012301354 -8.2521205f-5 -0.0001263249 -0.00013772438 4.961994f-6 8.9700465f-5 -6.8658504f-5 -1.1476328f-5 3.758845f-5 0.00015295781 1.7926752f-5 0.00011865484 -0.00019239023 -1.3022018f-5 -6.4974697f-6 0.00014625382 1.4175376f-5 0.00016034915 -0.00015811335 9.880131f-5 5.5428805f-5 0.00032604838 -0.0001189528 -5.0487295f-5 1.6045811f-5 -7.2142655f-5 -4.374798f-5 0.000114474555 3.083645f-5 -6.360411f-5 -8.787862f-5; -0.00015718164 -0.00013256459 0.00010390622 -0.00013066338 -5.380833f-5 9.8112505f-5 -0.0001062218 -9.5050935f-5 -0.000116765594 7.254451f-5 -9.5531264f-5 4.0165447f-5 -5.982197f-5 -7.161505f-5 8.159179f-5 -1.33203f-5 5.9567094f-5 0.00010057425 3.767484f-5 -9.392254f-5 -1.5189761f-5 -6.896647f-5 -7.531927f-5 -8.952318f-5 -1.7392378f-5 0.00016475619 -8.851672f-5 -1.6868911f-5 9.546133f-5 -0.0001390619 1.0259094f-5 7.485221f-5; 
7.836018f-5 -5.1448103f-5 -0.00022476113 -1.0844748f-5 2.2936816f-5 0.0001464918 -1.3959495f-5 -6.8406815f-5 -0.00017202055 -0.0003730353 7.967585f-6 -0.00011042028 5.7887402f-5 -4.832778f-5 -6.210573f-5 0.00016352016 -9.532411f-5 6.3706298f-6 -6.787228f-5 7.695039f-6 1.2406798f-5 -7.571284f-5 -8.513476f-5 0.00014600693 -7.038365f-5 6.644303f-5 5.7982754f-5 3.9307357f-5 8.416986f-5 1.7834585f-5 1.0704901f-5 0.00014045871; -5.1166866f-5 0.000118130134 -1.9062245f-5 0.0001288369 2.289677f-5 -1.6258213f-5 -0.00018586451 -4.932245f-5 -2.2379585f-5 -8.9918365f-5 -6.714307f-5 -6.736823f-5 -0.00017617176 -2.6530784f-5 3.0289197f-5 0.00011638353 3.0797986f-5 -4.904833f-5 2.7100032f-5 0.000101551355 3.7605798f-5 -0.00013092897 -9.8919045f-5 -9.242533f-5 -0.00013284868 -5.978484f-5 8.366794f-5 -6.4631924f-5 -3.196996f-5 -3.6547117f-5 5.672231f-5 -3.947497f-5; -5.560131f-5 -8.2061815f-5 -2.0887528f-6 8.910148f-5 -3.6316716f-5 -0.00010124355 -1.4694341f-5 8.534235f-5 -7.031141f-5 1.1211088f-5 2.5200778f-5 0.00010802393 4.8337082f-5 -9.587753f-6 -9.206158f-5 -0.00013790032 -1.99404f-5 6.260721f-5 0.00013002381 0.00019311045 7.303784f-5 0.0001445188 -8.3771025f-5 7.137857f-5 -4.041815f-5 3.355392f-5 5.9870792f-5 -5.218674f-5 8.7080596f-5 0.00010186078 4.1814532f-5 -0.000111262336; 6.606352f-5 -6.150248f-5 8.210773f-5 -9.016514f-5 -0.00011009504 0.0001550888 -0.00011484237 -7.420078f-5 0.00010323086 3.9132124f-6 -6.391634f-5 -3.731381f-5 5.9608403f-5 4.607404f-6 -0.00023805685 2.6621778f-5 2.6015246f-5 -9.1109614f-5 -5.9198832f-5 9.698664f-6 5.7506928f-5 2.2875802f-5 -0.0001901133 7.205064f-5 9.734149f-7 3.724733f-5 -7.2008814f-5 -3.4994006f-5 6.4695764f-6 -4.2347427f-5 4.9663086f-5 5.496611f-5; 5.0384915f-5 0.000118473254 3.3816132f-6 -6.851704f-5 -1.2578998f-5 3.6787533f-5 -3.8275384f-6 7.4819254f-5 -3.9932074f-6 -5.4376884f-5 -3.012484f-5 2.6639766f-5 4.2870353f-5 -0.00018274917 -0.00016184676 9.608435f-5 1.2335661f-5 0.00023282038 -7.778209f-6 0.00011903286 0.00011334837 
5.0333594f-5 1.4975886f-5 -0.00012922482 -2.8382334f-5 6.8171634f-5 -0.00019589791 -0.000119898636 -3.4977693f-5 6.930717f-5 0.00016328879 4.393619f-5; 3.443207f-5 0.00011420038 8.3658524f-5 7.691841f-5 0.00012734943 8.626818f-5 5.0087732f-5 3.2113287f-5 -5.088789f-5 -9.776833f-5 -0.00012338934 -4.8360827f-5 -0.000102779544 0.00011946895 -6.271452f-5 3.7976715f-5 -7.302555f-5 -0.00013380531 5.9041875f-5 5.3997923f-5 0.00014882041 0.00020540817 9.2584505f-6 -9.9955774f-5 0.000117808224 5.2494604f-5 2.2330842f-5 -2.5898704f-5 3.09403f-5 -7.287721f-6 6.416347f-5 5.2210657f-5], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[0.0001520579 -0.00010501084 3.00929f-5 -1.9047699f-5 -7.5644384f-5 -0.00010074411 -4.6865724f-5 2.2768796f-5 1.4279139f-5 6.708161f-5 -9.80768f-5 3.8763856f-6 -0.000110074005 0.00016275895 4.5885874f-5 3.0104813f-5 -3.42877f-5 -7.2143004f-5 2.9812692f-5 -0.00016606944 0.00014338488 -0.00010273319 -2.9174448f-6 6.226655f-5 -3.6669273f-5 -8.180466f-5 6.6031374f-5 -6.4420124f-7 -0.00016366721 8.533078f-5 5.4962557f-6 -9.168456f-6; -1.5360089f-5 -1.4792761f-7 0.00012271854 8.798139f-6 -9.6559335f-5 3.7150752f-5 -0.000109552144 -0.000110902096 -2.8779235f-5 -0.00011555239 -8.7334f-5 -1.4908292f-5 0.00014913059 1.6136699f-5 0.0002738758 1.8299825f-5 4.3418375f-5 -5.441717f-7 -3.8125167f-5 0.00015764602 -4.3840675f-5 0.00010458747 -0.00019799397 0.00016922835 4.385056f-5 -9.336735f-5 -5.2068117f-5 -0.00020108734 -1.3076445f-6 0.00010695125 0.0001460799 0.00014052542], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we need Float64.

julia
# Float64 copy of the (Float32) parameters from Lux.setup, stored as a flat ComponentArray;
# this is the parameter vector handed to the ODE problem and the optimizer.
const params = ComponentArray{Float64}(ps)

# Bundle the network with its (empty) state so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, st)
Lux.StatefulLuxLayer{true, Lux.Chain{@NamedTuple{layer_1::Lux.WrappedFunction{Base.Fix1{typeof(broadcast), typeof(cos)}}, layer_2::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(cos), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}, layer_4::Lux.Dense{true, typeof(identity), PartialFunctions.PartialFunction{nothing, nothing, typeof(WeightInitializers.truncated_normal), Tuple{}, @NamedTuple{std::Float64}}, typeof(WeightInitializers.zeros32)}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}, layer_4::@NamedTuple{}}}(Chain(), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()), nothing)

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, with the rates multiplied by a neural-network correction, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = [χ, ϕ]`.

Both rates share the factor `(1 + e cos χ)^2 / (M p^(3/2))`. That factor is scaled
element-wise by the two outputs of `1 .+ nn_model([χ], nn_params)`. The constants
`p, M, e` are read from the global `ode_model_params` and the network from the global
`nn_model`; `t` is unused. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example the network state `st` is an empty NamedTuple, so it can be
    # ignored; in general, `st` should be used to carry the network's state.
    correction = 1 .+ nn_model([χ], nn_params)

    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Same fixed-step RK4 setup as for the true model, but with the neural-network right-hand
# side ODE_model and the untrained parameters `params`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Compare the reference waveform with the prediction of the untrained neural ODE.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    # Makie's marker keyword is `marker`; `markershape`/`markeralpha` are Plots.jl
    # keywords that Makie does not accept. Marker transparency is set via `alpha`.
    s1 = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5,
        strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5,
        strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
"""
    loss(θ)

Sum of squared differences between the reference `waveform` and the waveform predicted
by the neural ODE `prob_nn` with parameters `θ`. Integration uses fixed-step RK4 (global
`dt`), saving at `tsteps`. Returns `(loss_value, pred_waveform)`; the extra prediction
is presumably forwarded by Optimization.jl to `callback` as its third argument.
"""
function loss(θ)
    traj = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    pred_waveform = first(compute_waveform(dt_data, Array(traj), mass_ratio,
        ode_model_params))
    misfit = sum(abs2, waveform .- pred_waveform)
    return misfit, pred_waveform
end
loss (generic function with 1 method)

Warm up the loss function

julia
loss(params)
(0.17800437671059263, [-0.024267617148096984, -0.023482534249793213, -0.022697451351489286, -0.021370929424605563, -0.019475238868439264, -0.016970130283282342, -0.013802762791225434, -0.00990474536309978, -0.00519441191536813, 0.00042351562634138193, 0.00704869522196032, 0.014764165577191245, 0.023588623577619292, 0.033347789983459476, 0.04339558061303184, 0.05195496168361476, 0.05472400786852236, 0.04260448282779364, 0.0020441104475288083, -0.0662865111396386, -0.1103973415118011, -0.0763371970176771, -0.006743114921994411, 0.038957312755559234, 0.054379413599510915, 0.05299295532543121, 0.04483057549755673, 0.034773247927447584, 0.024860054478857096, 0.015848352861360517, 0.00795084248140009, 0.0011650186713889586, -0.004590750665755806, -0.009418071976421052, -0.013415470863880457, -0.016667287846435147, -0.019245177899249337, -0.021205106302130534, -0.022589551985236912, -0.023428421908948685, -0.023740042005191513, -0.02353152155082297, -0.02280004554161169, -0.021530540439143066, -0.019697818518264614, -0.017264236598703967, -0.014178961571078464, -0.010377435083952662, -0.005780632427187352, -0.00029665698305362707, 0.006171699426036556, 0.013710929884887603, 0.022348311536654664, 0.03194788771293361, 0.04195685829516484, 0.05082983031597753, 0.05478045395579029, 0.04561005168440439, 0.010058799151769562, -0.05578499989128355, -0.10838935395049118, -0.0856186831942853, -0.01624434749541645, 0.03461276635586358, 0.05370272521950142, 0.05392321069547177, 0.046255028199066024, 0.0362218619475514, 0.026163213785645213, 0.016963755562741357, 0.008878835091809447, 0.0019299630373325095, -0.003970665622077228, -0.008916228783703354, -0.013017303584517368, -0.016354090920544566, -0.019008291922656777, -0.021032340656301617, -0.022476539322617135, -0.023369080382073824, -0.02373288739847211, -0.02357610742157605, -0.022896803291084816, -0.021683881619124202, -0.01991397299004144, -0.017548612250272853, -0.01454547614679618, -0.010835674623144978, 
-0.006351254335920266, -0.000995338216265178, 0.005319178850037957, 0.012687428518615288, 0.02113874488947553, 0.030575425171667135, 0.04051981611781407, 0.0496349220743119, 0.05459077373523163, 0.04803419183604194, 0.017266748759051133, -0.0451255543502094, -0.10445345226809685, -0.09382689999476237, -0.026343627649616827, 0.02951557296519148, 0.05264340898855374, 0.05472280544526731, 0.04765994531937303, 0.03769074558643306, 0.027500415409006645, 0.018107804185569103, 0.009838045801079469, 0.002713970091542699, -0.0033282073203844014, -0.008402740711279255, -0.012604328833740257, -0.016033979435559283, -0.01876131388137974, -0.020854815525868722, -0.02235664885355595, -0.023305403226327237, -0.02372077172944983, -0.023615224020030575, -0.02298938771276598, -0.021831172093424705, -0.020122391442312818, -0.01782540422046344, -0.014900619865485444, -0.011281939840628401, -0.006904916986104557, -0.0016749839034083899, 0.004491691434440925, 0.011691432903865897, 0.019961402948003658, 0.029229349698742805, 0.039091156373176464, 0.04838546932615984, 0.054188528709856125, 0.049939774459645224, 0.023658507938632552, -0.03455845100573946, -0.09876269326222205, -0.10065768700332835, -0.03689243344495253, 0.023619282125103263, 0.05115037697029299, 0.05536381004339074, 0.04903555832497308, 0.039179415435076366, 0.028864984561191617, 0.01928934770034155, 0.010821934745962474, 0.0035256119847471364, -0.0026725773445663214, -0.00786809555255026, -0.012182133741039767, -0.01570142800759207, -0.018509947812025425, -0.020667374448715026, -0.022232499812250794, -0.023236060553511573, -0.023703989581752866, -0.02364991322570562, -0.02307603676607412, -0.021973628011828478, -0.020321904782716626, -0.018096285580256912, -0.01524479455152809, -0.011714777426765792, -0.00744395749745847, -0.0023368568156329245, 0.0036887528206711725, 0.010724374183085476, 0.01881284019129563, 0.027911677757576744, 0.03767448790897921, 0.04709374570062228, 0.05360960688227922, 0.051379187910055586, 
0.029260688731181504, -0.02431650958547108, -0.09153493023987687, -0.10585279945111689, -0.04769812372280692, 0.016899733926547523, 0.04915880284327563, 0.055819105300223784, 0.05036979758543504, 0.040680179212437655, 0.03026726803302371, 0.020498736857160256, 0.011836769665169243, 0.004357999205553569, -0.001988740261787425, -0.007323113793352741, -0.011746155312204068, -0.015360893449177373, -0.018244880008393374, -0.020476682155106746, -0.022101884350369107, -0.023161824099228942, -0.023681917050726174, -0.02367981612379197, -0.02315771363120422, -0.022108462099311155, -0.020519501508844473, -0.018355563952248657, -0.015580167240682558, -0.012136575727739926, -0.007967030626100938, -0.0029790760433027635, 0.0029075136857194, 0.009782654967168947, 0.01769606788225503, 0.0266229230658257, 0.036272356372459895, 0.045771999569878354, 0.052879504524428925, 0.05240924922202618, 0.03410234399307332, -0.014564038500328873, -0.08305492918554294, -0.10921053617637562, -0.05850983334944438, 0.009334568601989305, 0.046615960006907334, 0.056050768315097185, 0.05164877890216536, 0.04219343530178758, 0.031696310160852295, 0.02174391308867621, 0.012881594073817535, 0.005218173135491784, -0.0012887044949048691, -0.006759407401640876, -0.011296719785942905, -0.015008230893415983, -0.017974384686115494, -0.020277317026350498, -0.02196531100092958, -0.023082143848211105, -0.023655322340226614, -0.02370474119564277, -0.02323435564734047, -0.022239033106176403, -0.020707998042545223, -0.01860848477961709, -0.015906039067649785, -0.012546771133240427, -0.00847574264920083, -0.004404714165161455])

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(θ, l, args...)

Append the current loss `l` to the global `losses` history and print a one-line progress
report containing the iteration count (the number of losses recorded so far) and the loss.

`θ` is the optimizer's current parameters/state and is not inspected. Any further arguments
the optimizer forwards after the loss are accepted and ignored. In this tutorial these are
presumably the predicted waveform returned alongside the loss; verify against `loss`.
Accepting them as varargs keeps the callback working whether the loss function returns
extra values or not.

Always returns `false`, so the optimizer is never asked to stop early.
"""
function callback(θ, l, args...)
    push!(losses, l)
    # The empty-string `%10s` fields are deliberate padding between the columns.
    @printf "Training %10s Iteration: %5d %10s Loss: %.10f\n" "" length(losses) "" l
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This appears to work well because the Newtonian model provides a very good initial guess.

julia
# Gradients of the scalar loss are computed with Zygote (reverse-mode AD).
adtype = Optimization.AutoZygote()

# The objective depends only on the parameters being optimized; the second
# (hyperparameter) argument Optimization.jl passes is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations;
# `callback` records and prints the loss each iteration.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback=callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-3.7090494515651475e-6; -0.00011290690599682988; 9.121011680684222e-5; 0.00013011461123796781; -0.00019051758863482108; 9.79100077528666e-5; -0.00017315523291444517; 8.074661309354402e-5; -8.14383893156077e-5; 4.991976675222993e-5; -0.00011104928853436303; -3.756931982927671e-5; -8.545671880706262e-5; 4.145344064454085e-5; 2.531830068615997e-5; -3.324981298646107e-5; -3.296637441962647e-5; -4.547579010247641e-6; -0.00010109940922101978; -0.00010161675891126477; -4.226428984560484e-7; 6.633573320853808e-6; -0.0001256601535713415; -8.305306982934979e-5; -0.00019378383876706676; 0.00016258833056755661; 0.0001227200264112042; 0.00011263116903142533; 0.00011468410229999117; 0.0001435890881108357; 0.00011572713265185692; -8.055818761933531e-5;;], bias = [-8.04396505638093e-18; -2.742380373271208e-16; 7.101671137277727e-17; 3.04228217159861e-16; -1.6746315433493231e-16; 8.26733608056771e-17; -3.1127009608839165e-16; 1.200916140177227e-16; -1.4223703747999006e-16; 5.770766282772921e-17; -2.1903586998041962e-16; -7.605963566008613e-17; -5.206070676144167e-17; 6.001497969434356e-18; 3.80245781411334e-17; -8.357245536917506e-18; -3.3808515391286033e-18; -2.6361726038216e-18; -2.07404339561081e-16; -1.5357930954955266e-16; -6.174442794414408e-19; 1.6640660141639814e-17; -1.9177392245093457e-16; -1.3883737703627887e-16; -5.386309483623671e-16; 2.8528184025587146e-16; -9.693933383919994e-17; -5.4473856591060884e-17; 1.0444366566717535e-16; 3.870997897646192e-16; 1.102753118853412e-16; 1.771475695159801e-17;;]), layer_3 = (weight = [-8.682312636807263e-5 1.1948086465476854e-5 0.00010715425287203825 -8.30886319831977e-5 -0.00011259806069709842 -9.309968027222706e-5 6.22483448737509e-5 -6.770233176352823e-5 -0.00016841093474936092 -5.402527404071942e-5 2.606644647900249e-5 -2.8471332372244503e-5 0.00010106295212467908 -3.06935953645347e-5 -0.00011712393188787214 -0.00010991986073491596 8.987223827574897e-5 
-8.041623823515771e-5 1.3640645861048528e-6 -0.0001381466763646495 -0.0001963007134705194 -8.227509988642925e-5 -0.00013082937675929225 -0.00020372079325311186 -9.399501596043138e-5 -8.943181910718292e-5 0.00017659268178110278 -9.356008378322509e-6 -0.00010592479248001322 -1.112787049489201e-5 -4.9822710026577475e-5 0.00011160221310861034; 0.00010110425492663232 -3.2136700693223995e-5 -0.00011878254016714942 7.59901521718687e-5 -0.00021103788466515658 -7.268597991292367e-5 -1.529806095805882e-6 1.5434952646912258e-5 0.00016244281133317764 0.00022393016834890446 0.00011216117846692558 0.00020664029571840656 1.5850558983812216e-5 1.3238554033330369e-6 0.00010076631579928314 4.9836791649304035e-5 -0.00019492533473616309 -5.4143525713278756e-5 -0.00011022040424955162 5.1731945777143305e-5 7.402629117628858e-5 0.00018897306282220699 -0.00016935817265272012 -1.8283621184808508e-5 4.624997373812961e-5 -6.008817953495614e-5 6.350463800976277e-5 4.849896496074804e-5 3.362240377959334e-5 0.0002840781670869837 0.00011443564281711793 0.00012609678370559052; -5.7808042744176965e-5 -0.0001388479496892604 -0.00015048543643951903 0.00018329062079264967 0.0001643469812041016 -0.00010018725804566 -6.262362801517568e-5 7.675914115500973e-6 -0.0001042781943180642 0.00021736612763907445 8.16751933814897e-5 0.0001424971786253126 0.0002090233982233285 8.791630609612626e-5 7.207452365383251e-5 1.4184084409908235e-6 -0.00012735225680098213 -0.00014769507759064676 -0.00015657899273373733 5.980084535967944e-5 5.076105737662638e-5 2.776066760252924e-5 -0.00013477278769296377 -0.00011382584414259042 -5.0482466737054905e-5 1.6778791927023725e-5 2.7352910204900056e-5 -8.398113930179424e-5 -1.1924279693811466e-5 0.0002690395275056276 9.235338877606499e-5 -2.7453658891496833e-5; 1.4756263320424754e-5 -8.426040485271417e-5 -2.7081558952623985e-6 3.696960730242257e-5 6.82189721404211e-5 -6.842210124628147e-7 0.00010850698991209538 9.134398815452134e-5 0.0001615254305791763 0.0001330839305440629 
-2.1247699711850514e-5 -1.4763880150593594e-5 0.00013471111115298228 -6.522852972618153e-5 8.029147384452816e-5 -0.00014950163212840273 -6.591952742080037e-5 2.7519893686377602e-5 -5.394120935004296e-5 -0.000155859625826153 -4.351788347473164e-5 -1.689868069374867e-7 -6.479976852163873e-6 4.113346322203991e-5 -6.607190780112468e-5 3.444789136058091e-5 0.00013884404425273418 -0.00011263850219317875 2.0686814517218716e-5 -4.037311461320342e-6 -3.116359848415663e-5 2.0009815765051326e-5; -0.00034580299255832555 -5.568215999669835e-6 -4.384923089456536e-5 -9.242645613776558e-5 -4.94055923241692e-5 6.993541024930154e-5 -6.012101151790399e-5 -1.3861719014685874e-5 6.929291408813635e-5 3.4809641077896906e-6 3.1556920259749224e-5 -3.591496646332896e-5 4.59028895078426e-5 9.437426048169527e-5 0.00016551631121036197 -4.651256607541249e-5 -0.00023693419210541432 0.00011365593378503356 5.248335288560378e-5 -1.9570693655266466e-5 -6.763702843965484e-5 1.7315011757440087e-5 0.00012141591720355547 -8.476257356743232e-5 -7.750410007937137e-5 3.476254711350215e-5 -8.434233608350342e-5 1.4541795064719372e-5 -8.236264260413463e-5 -3.986724481146526e-5 -2.8026884354044143e-5 -6.839901765271025e-5; 0.00010395090708787224 0.00016245344073555248 6.633543585936956e-5 5.2158157458990034e-6 8.880675073753486e-5 0.00017759224925699885 0.00025375620231073647 -0.00010279843348965322 6.108322037195499e-5 3.955121484045287e-6 -4.697608614063947e-5 -4.7928458060618095e-5 -0.00021001148792405057 1.8182282529756422e-5 0.00017734598719560128 0.00011340013503374988 4.400224048456634e-5 5.065375358842914e-5 1.4447185149913846e-5 -5.798202340796445e-5 0.00015236570385330025 9.696335388159978e-5 -0.00018106238798195156 -6.265442147108121e-5 5.794929037825548e-7 -2.7244547609925178e-6 0.00013374575706707749 -1.579315625262153e-5 1.2373358059411378e-5 0.00011119861943456965 1.785454247597803e-5 -3.484726124864342e-5; -0.00021086803619887562 0.0002116988892094369 8.280557656056897e-5 0.0001379720181585406 
-8.51023521381408e-5 -8.809731091510849e-5 -7.043265370791247e-5 -0.00011623064504143044 1.2138429404369406e-5 8.030100864784323e-6 7.53713654652459e-6 0.00016929206164881637 -0.00010542712855290853 -0.00016996007045868747 8.366275712709965e-5 5.777848698293857e-5 -0.00014327242545455625 8.027357828764972e-6 1.0585692034877458e-6 7.147547719163441e-5 -0.0002573005201401922 5.774747321360863e-5 1.163862568595267e-5 7.656623190363656e-5 -1.54463411736053e-5 0.0002152511717877421 5.355852251617559e-5 0.00022177302005096778 -6.126937089626265e-5 3.455684988969988e-5 0.00014511984615913997 4.897473834084151e-5; 8.175444110850667e-5 0.0001565767221301251 0.00015745308212092527 4.652202789759215e-5 0.0002479474120456065 4.982829829041828e-7 -4.845564888354161e-5 -4.8520899671413875e-5 -6.745274857064242e-5 -8.765678774477241e-5 -9.049412745196193e-5 2.8764511281342995e-5 -2.60664650148762e-5 5.446994566330465e-5 1.6821838891073422e-6 -0.00012612269472708182 9.831484805648029e-5 -6.638994216597775e-5 0.00013596889641223763 9.778103287419583e-5 -0.0001924588954160469 0.0001639710304708672 3.5046227477435105e-5 1.94883929190178e-5 -1.4541841100153136e-5 0.00013316471324389273 -0.0002442271155622411 8.752067389752877e-5 -0.00013352964868215705 7.5103231403045325e-6 6.136308968154842e-5 -6.82254916490567e-5; -2.5452205198614468e-5 0.00010519503342183144 0.00014386797050209364 0.00011789577041016905 -0.00015945801493981018 -4.616891380320601e-5 -0.00020111156260864245 7.873765792427158e-5 -0.0001426030990554994 -3.8197956717718815e-5 -0.00011029151716539818 1.072923662976181e-5 8.299040421805881e-5 7.036760001172611e-5 1.7816921564720308e-5 0.0001227359702457608 -8.001289604036797e-5 -1.078782822188928e-5 -5.6475016725392004e-5 -0.00010031873313747611 5.289698168015332e-5 1.0884327304286358e-5 -1.012256922732374e-5 -8.578351623183528e-5 -8.716060762562828e-5 -0.00010163482290263802 2.4195726320239997e-5 1.5007238681927059e-5 -2.3866983218259706e-5 4.2443384183199935e-5 
-8.136300090737777e-5 0.00012188381736594756; -1.0855958360419855e-5 -0.00012918159377288437 -8.178599132218195e-5 -4.995305102776749e-5 -1.6104029990860971e-6 -2.4465817157692942e-5 -6.486978277089861e-5 0.0001185378783266717 -0.00016702037141651789 0.00010656486806698791 -0.00012960509816177586 1.8496107761227507e-5 -0.00021187502029345253 7.260082161525847e-5 -0.00014741125819267212 0.00016342316626111577 -0.00014089720975601515 3.949130000552604e-5 -0.00014319466613226214 -0.000188651531469008 0.00014324916224001414 -7.692536787356737e-5 3.001605505422115e-5 -0.00012055237352608622 0.00012687949452092713 -4.096551533430968e-5 0.0001622728519142402 0.00019549018866389623 4.041470633826822e-5 -2.496522797438488e-5 -0.0001286917180986361 1.376469119415755e-6; -7.837833094400808e-5 -0.00011238582213863658 6.497655864350062e-5 4.2292774755050595e-6 0.00013774461127598746 3.925990109692037e-5 -0.00012331386664170586 -0.00021519733694604307 -2.1021096689404686e-5 4.308894655095241e-5 -5.189657007449283e-5 -0.00010655194477168209 6.310365439403251e-5 -8.770451517067179e-5 3.29429256102125e-5 -3.549356391759921e-5 0.00013565337009039375 6.210184220788178e-6 -0.00016339916895831224 -9.515489931673017e-5 0.00013962240496892962 1.3228099088689915e-5 -0.00017753411360771912 -4.927983740217651e-5 0.00018200533828479355 -0.00011154360094094872 -7.816948913259095e-5 6.0307012755769006e-5 -4.026485679659875e-5 -0.0001311990401869356 -0.00015108365026989809 -2.906141729782049e-5; -6.888998380949434e-5 -5.282241748850019e-5 9.792440483957527e-5 -0.00012979949154746085 -7.932305946801833e-5 -0.00014629533892352182 -7.6939732239999e-5 6.681864557700292e-5 0.00016201463382836486 -4.801223825401426e-5 0.0002518260423713908 4.026676266599685e-5 2.887945789930909e-5 3.851177622360405e-5 2.9500319000508492e-5 -0.00013855903147108609 0.0001007538210289205 5.7137820262287806e-5 6.101693521864643e-5 -3.775620483415275e-5 -0.00022723696223314564 -6.0687087385158e-5 6.915611972013229e-5 
9.998069232475059e-5 -0.00019921311671701345 0.00012404737732517833 9.359627951339305e-6 -2.7578669901999327e-5 0.00015752762352727183 0.00017333501979387871 -7.461703022131164e-6 1.3804813319582635e-5; -2.8117213340067263e-5 -0.00015431723387159056 4.72005076444027e-5 -0.00010623734215917975 -7.009397905494997e-5 -8.713761872790859e-5 -6.987016332278916e-5 -0.00011344629547970915 -2.141788082376654e-5 0.00013285290969355766 -7.60150951537461e-5 -4.18566533044761e-5 -0.00022417732635226538 -5.102708208196159e-6 0.00011156943855161313 -5.5480691124158226e-5 -8.196478393807476e-5 -2.9136080236696583e-5 5.098852761143804e-5 2.1990582346874307e-5 4.6282416398696626e-5 4.677419473586055e-5 -0.00018168152838192087 -0.0002613615734008377 -0.00016970553497958724 9.117655189764218e-5 -6.671863321393045e-5 4.452824486357501e-5 -8.717175752102919e-5 -2.47716225276506e-5 0.00017731017486945244 -2.7759229015392956e-6; 4.101012015224618e-5 0.00023882875390325513 7.931204160184556e-5 3.1108622319043636e-5 -4.124062619314401e-5 -0.00010258686237295236 -6.890432272132185e-5 2.6576726635433797e-5 5.832165694653991e-5 7.297876433211493e-5 -8.027161678557439e-5 5.941111882192229e-5 -0.0001820025500238595 0.00010147059216604702 -3.280995227770713e-5 0.00014192765226097 0.00012037649136244471 4.470044947437194e-5 2.865802886288519e-5 3.0012402898030317e-5 -0.00010761979504979034 7.068838018666202e-5 -0.00024256452164638103 4.0536233394853736e-5 -0.00026646765767156987 0.00015984801572256185 -0.00010597095391505624 -8.260218604493118e-5 4.4651813335696356e-5 0.0001246317696894761 -9.805073208200927e-5 0.00011118358812735423; 6.151760809340397e-5 5.714330965170008e-5 0.00019483475390339141 -0.0001060765813129878 -6.92479175890008e-5 0.00015613429915142907 -7.37370088139621e-5 -0.00012753964491670874 8.804592118082285e-6 0.00010866836001847475 7.697989113844352e-5 -1.1423456806675955e-5 -5.061730542678964e-6 -0.00010944407639575502 5.857057968720605e-5 6.55093730635265e-5 
1.8557332997864525e-5 1.8556044430303032e-6 0.00012015030862280756 5.9558894837802464e-5 7.45815681617975e-5 5.622588415614208e-5 -0.00010293728208885571 3.529879305611399e-5 0.00025098276583923303 0.00010666470681065377 -9.014932027961648e-6 -0.00011004690403601774 9.56718699050255e-5 -3.0356609011833855e-5 7.978844715730519e-5 -1.6405460508707095e-5; 8.194485955227575e-5 -2.8630584061657154e-5 -3.607416679921927e-6 0.0001256449230951938 -2.938469705061499e-5 -9.885128825597452e-6 9.03305098419678e-5 -2.1388987157302864e-5 -3.9239277690108885e-5 -8.810033076388507e-5 7.50992530030908e-5 2.805160867337459e-5 2.6748135418705206e-5 -8.651725701002101e-5 1.2968737714070795e-5 -8.789712781963529e-5 -1.4150496177355603e-5 0.0001253794961614389 0.000134673906281197 0.00010448152708176295 2.431432213979716e-5 8.600029642748121e-5 0.00017886778179143918 -0.00013943908023593963 0.00011804887427767682 1.381145731482357e-5 -0.00021141764053555733 3.475853183243977e-5 -3.6227569658704435e-5 -2.241140835923546e-5 3.062861461711627e-5 2.546892930167268e-5; 7.216422120967182e-5 -0.00012900384869732383 5.694772001282727e-5 5.6412289567959e-5 -8.47918697268842e-6 4.7447407568539277e-5 4.0836305133271514e-5 -4.448759106759538e-5 2.4397728151918293e-6 -7.443520705288917e-5 -0.00014937728671662938 -0.00014152791088287451 7.888755434974304e-6 -6.7889874758893474e-6 -0.00012269840196546848 5.4566779298906004e-5 -7.286337465353984e-5 0.00011986798956916902 0.00017193438573434475 6.267012970914591e-5 -9.644335629724664e-5 -8.713960384790551e-5 2.5220790420860243e-5 5.1191269748810804e-5 -4.432520624555527e-5 4.028368160055829e-5 8.364861453577506e-5 8.941102017109792e-5 -4.685116770866739e-5 -1.865854180374432e-5 3.0291468946813275e-5 -0.0001470606363641899; 5.317333127704375e-6 0.00010540436833494675 -0.00015732164992854247 6.241608989753762e-5 1.410119839659987e-5 0.00026227192629690865 -7.606789925891008e-7 -6.955438400233277e-6 -0.000112782239191198 -6.873063048017927e-5 
-4.151564589605301e-5 1.1689011992228431e-5 0.0001032853693710996 2.4201294402399067e-5 7.212274849977706e-5 -6.387769044279727e-5 -2.2706455090275488e-5 -5.9626417530426365e-5 -8.752080557072933e-5 0.0002009330861280073 0.00013447109868381586 -7.01949096740741e-5 -7.42863043318214e-5 6.0527804273682225e-5 8.812901851582568e-5 -0.00018491037862745867 4.468465758932903e-5 0.00022694064737393857 -7.845402378492687e-5 0.0002072160792546597 7.857511811978674e-6 4.547929038623246e-5; -2.3371270524988482e-5 2.8238726646921043e-5 8.44364806670249e-5 -0.00022490110468254172 2.7433189205835516e-6 3.4001750738641394e-5 0.0001388808998871164 5.008878635523772e-6 -5.5447216889772634e-5 -0.00012550702713099153 0.00011159712112373892 -7.512433848944351e-5 0.0001018558653430566 -0.00017395802178555021 -0.0001294802384091993 -6.017012465065206e-5 5.486322881439939e-6 -6.094856844185719e-5 5.9640788369395026e-5 -3.4008560485856413e-5 1.2201673873166016e-5 -7.908332527051548e-5 -4.507881358051427e-5 8.376656142161411e-5 8.019866463826442e-5 -1.4101502254671919e-5 0.00016897765117562615 8.049699345235934e-5 1.75017578857275e-5 -4.232156761474627e-5 -4.237157527142906e-5 0.00011318007846228026; 0.00025280462485992175 -2.7144811075201555e-5 -1.2697790599274123e-5 -6.978480033555299e-5 2.2493166933629442e-5 8.28002308381049e-5 -0.00014780212169259135 -1.9762082421008567e-5 -2.5004200698246325e-5 -4.958063019418336e-5 -0.00010149691607781901 5.027903588509944e-5 2.0691561611814324e-5 9.941370134495953e-5 2.132208700371268e-5 0.00021658304521136735 -0.00011038197354625153 5.714887137313204e-5 3.0194247837274133e-5 2.9620174781515054e-5 -0.00010475763648261486 -2.4041486004504218e-5 -0.00020974771851885364 0.0001317601496239452 -1.5962151711362678e-5 -0.00012283589550039997 -2.9096059198279315e-5 1.256130043687467e-6 7.785130636787756e-5 3.1583386397899475e-5 -3.9036181452472084e-5 1.3245511208000052e-5; 4.03843634210519e-5 -6.687324170182557e-5 0.0002311988111266796 -3.578420780609607e-7 
-2.3726174584454792e-5 -0.00010336362432847594 -5.89623476464533e-5 0.00016770238759070073 -0.00012970165229328546 0.00012703237476832904 8.203200301228956e-5 3.875677899648252e-5 8.418371327727389e-6 9.104529195772275e-5 0.00018846865456160033 -6.473941433634772e-5 3.905923327854675e-5 2.6785598640143074e-5 -9.05666169112924e-6 -0.00013703433327303188 -2.8687557313079798e-5 9.170768058700407e-5 9.643815170581128e-5 0.00010136981760458428 -0.0001867003400896033 0.00012814041580145986 5.723505924994788e-5 -6.984794421280828e-5 5.0431496821845524e-5 0.00017201688493674865 -6.945490425844823e-5 -0.00019872775996029936; 0.00011840403151405387 -0.0001828145455644952 -8.084026899037557e-6 -0.00011402630350380589 -1.969028614872417e-5 -0.00014649389647706482 0.0001246696933096646 8.758127789883937e-5 0.0003080948676372419 -6.514479043720985e-5 0.00012747604471337795 -9.889029622603857e-5 -8.197648874078847e-5 -0.0001802024185733377 -9.971554989055414e-5 -4.988627058514099e-5 0.00010313284421977884 -4.1506157299174315e-5 0.0001420210768858427 5.0490523267629496e-5 -2.7595785081522835e-5 0.00012924686727752317 0.00012679291960489603 -3.4662700795474935e-5 0.00017229719508144665 -7.525969574522087e-5 1.423692031062371e-5 5.0563515674413974e-5 9.15675495209616e-5 -0.00013004797672215454 0.00013618190262173072 -7.448479898334417e-5; -3.928481679148389e-6 2.4765382672054584e-5 0.00020227119986362138 0.00016161247932744852 6.50625517793741e-5 5.248968610817657e-5 -7.415509734145203e-5 2.044422964377968e-5 -3.7334821497476475e-6 5.991737654245887e-5 0.00010478863279263088 5.811912725609418e-5 6.0253027381128954e-5 7.866336110443443e-5 5.3229334493398815e-5 0.00010895800388813696 3.866580319870012e-5 9.668357253237567e-5 0.00022623396963740674 -0.00010544703380742152 0.00013062369563546903 -8.104755378165345e-5 6.916483666136887e-6 -5.792716163863789e-5 0.0002441646940625991 0.00017145650073345617 -9.872598437664991e-5 -2.8012607604754154e-5 -7.812423408336554e-5 
-0.00011074043121495308 -3.944600653771879e-6 0.0001216891398273752; 0.00024545705838483096 -0.00012564323532040155 0.00014229300160315642 -0.00021461520208156356 -0.00012228591930230586 -6.982438043103661e-5 -8.557538591200547e-6 -0.00011408249537105954 0.00012749967010114996 -3.6507213904937276e-5 7.783882921450822e-5 -1.3517584362657455e-5 -9.158455048798824e-5 3.517126043681427e-5 4.87677805847985e-5 0.00010928823923115024 -2.3203765691673756e-5 -6.533309186687418e-5 -3.487982593122009e-5 3.334349259041185e-5 -4.559850475396688e-5 9.81709106907276e-5 -0.0002564305233919875 9.271504114969371e-5 -0.00027783659441971587 -0.00016362594595754586 3.120652928721367e-5 -7.096743337222306e-5 -6.417564405799097e-6 -1.9857481844312832e-5 0.00024786371233296694 8.742738658673536e-6; 7.211735411200407e-5 0.00012301484585858075 -8.251989728652561e-5 -0.00012632359229375662 -0.00013772306791854963 4.9633012264014915e-6 8.970177205862802e-5 -6.865719714316151e-5 -1.1475020609773578e-5 3.758975563922526e-5 0.00015295912217932895 1.792805906969729e-5 0.00011865614707001839 -0.0001923889236326549 -1.302071049341293e-5 -6.496162395171916e-6 0.0001462551275932013 1.4176683553694417e-5 0.00016035046192935902 -0.000158112047383308 9.880261987799515e-5 5.543011267089949e-5 0.0003260496863827761 -0.00011895149017565513 -5.048598776686936e-5 1.6047118525034823e-5 -7.214134765437777e-5 -4.374667204448828e-5 0.00011447586203810867 3.083775792272552e-5 -6.360280039136201e-5 -8.787731584508514e-5; -0.0001571842772772785 -0.00013256722429098102 0.0001039035809572608 -0.00013066601656639134 -5.3810967592562945e-5 9.810986687632465e-5 -0.00010622443784987481 -9.50535728133477e-5 -0.00011676823175542119 7.25418753336887e-5 -9.553390243120066e-5 4.016280893671746e-5 -5.982460840635007e-5 -7.161769151019416e-5 8.158915027712826e-5 -1.3322938293580213e-5 5.956445550515709e-5 0.00010057161437549931 3.767220135166554e-5 -9.392518092674884e-5 -1.5192399196107202e-5 -6.896910469543816e-5 
-7.532191063539338e-5 -8.952581681108143e-5 -1.739501619331284e-5 0.0001647535495340801 -8.851935542217728e-5 -1.68715492417711e-5 9.545868982065963e-5 -0.00013906453802511824 1.0256455816559523e-5 7.484956898230985e-5; 7.835838530891554e-5 -5.1449897199824744e-5 -0.00022476292016935075 -1.0846542153635258e-5 2.2935022371646327e-5 0.00014649000857472024 -1.3961288658789446e-5 -6.840860825749689e-5 -0.00017202234104760563 -0.0003730370846998027 7.965790956325299e-6 -0.00011042207572163023 5.788560864030784e-5 -4.8329573673158115e-5 -6.210752710021084e-5 0.00016351836873664566 -9.53259043118074e-5 6.368836039568786e-6 -6.787407275541115e-5 7.693245407527952e-6 1.2405004582346659e-5 -7.571463012854587e-5 -8.513655140548985e-5 0.00014600513875930946 -7.038544228550688e-5 6.644123985181211e-5 5.7980960064846196e-5 3.9305563205664645e-5 8.416806284086586e-5 1.7832791655489787e-5 1.0703107164693615e-5 0.00014045691548827458; -5.11697754071151e-5 0.00011812722454961081 -1.906515449336521e-5 0.0001288339853537836 2.289385970400685e-5 -1.6261122301147764e-5 -0.00018586742139703217 -4.9325360169680396e-5 -2.2382494581327258e-5 -8.992127463971603e-5 -6.714597752834085e-5 -6.737113931267961e-5 -0.0001761746671498276 -2.6533693972676428e-5 3.0286287209665882e-5 0.00011638062364835756 3.0795076735732114e-5 -4.905123846657211e-5 2.709712221003526e-5 0.00010154844580865862 3.760288832308154e-5 -0.00013093187503801293 -9.892195435058347e-5 -9.242824221546686e-5 -0.00013285159279873435 -5.978775066140063e-5 8.366502965982914e-5 -6.463483355291502e-5 -3.19728706479472e-5 -3.655002642589191e-5 5.671940102657058e-5 -3.947787904581018e-5; -5.559961980298617e-5 -8.206012396232453e-5 -2.087061578379784e-6 8.910317003636162e-5 -3.63150252026845e-5 -0.0001012418619878316 -1.4692649959512226e-5 8.534403925904189e-5 -7.030972162584665e-5 1.121277957890055e-5 2.5202468829488188e-5 0.00010802562203633997 4.833877349182104e-5 -9.586061854232952e-6 -9.205988724107414e-5 -0.0001378986312132008 
-1.9938709077276792e-5 6.26088981577522e-5 0.00013002549851688465 0.00019311214012003404 7.303952909450522e-5 0.00014452048665288382 -8.376933371738927e-5 7.13802615125021e-5 -4.041645704394874e-5 3.355561014057083e-5 5.9872483258625654e-5 -5.2185049274444973e-5 8.708228736093949e-5 0.00010186247120753197 4.1816223098688904e-5 -0.00011126064461978966; 6.606182111374434e-5 -6.150417496735132e-5 8.210602797282716e-5 -9.016683501010433e-5 -0.00011009673792919611 0.000155087098816353 -0.00011484406744181061 -7.420248178584744e-5 0.00010322915883682432 3.9115138557658e-6 -6.391803938947541e-5 -3.731550853752126e-5 5.960670429578781e-5 4.605705695777003e-6 -0.00023805854610452998 2.6620079067125147e-5 2.6013547964449193e-5 -9.111131251609552e-5 -5.920053038914729e-5 9.696965632089067e-6 5.750522947584606e-5 2.2874103221165524e-5 -0.00019011499128589141 7.204894469998554e-5 9.717163872143531e-7 3.7245631089386675e-5 -7.201051224307864e-5 -3.499570405744853e-5 6.4678779077693346e-6 -4.234912515436697e-5 4.966138700786181e-5 5.4964410507294686e-5; 5.038606177841002e-5 0.00011847440037235761 3.382759862868877e-6 -6.851589132275445e-5 -1.257785170020108e-5 3.678867942944429e-5 -3.826391727498614e-6 7.482040091708678e-5 -3.992060734891338e-6 -5.437573708768407e-5 -3.0123693429053103e-5 2.6640912258882163e-5 4.287150004820313e-5 -0.00018274801841180213 -0.00016184561827394515 9.608550026605484e-5 1.2336807961402198e-5 0.00023282153124746553 -7.777062089010093e-6 0.0001190340088243806 0.00011334951414319376 5.033474081137836e-5 1.4977032881866042e-5 -0.00012922367313585815 -2.838118707296756e-5 6.817278045054159e-5 -0.0001958967652142788 -0.00011989748890919466 -3.497654615494295e-5 6.930831732751654e-5 0.00016328993370115876 4.393733688921114e-5; 3.4434913861529306e-5 0.0001142032271602179 8.366136916466346e-5 7.692125672492618e-5 0.0001273522722769298 8.627102233022622e-5 5.0090577308024766e-5 3.2116131653501545e-5 -5.088504604626933e-5 -9.776548523203574e-5 
-0.00012338649667780933 -4.835798227557916e-5 -0.00010277669917540241 0.00011947179172440401 -6.271167495268325e-5 3.797956032232407e-5 -7.302270369817474e-5 -0.0001338024665599583 5.904471999402271e-5 5.4000767879375345e-5 0.00014882325939853807 0.00020541101566069885 9.261295592478154e-6 -9.995292912916791e-5 0.00011781106895084867 5.24974495348847e-5 2.2333686676757757e-5 -2.589585875084275e-5 3.0943145422682426e-5 -7.284875944428291e-6 6.416631385236797e-5 5.2213501651030885e-5], bias = [-5.708899662680442e-9; 3.943486682317774e-9; 9.247958636517974e-10; 8.886703030924312e-10; -2.454942311040458e-9; 4.124613931843463e-9; 1.4972766128303122e-9; 1.5485248980697946e-9; -1.314736816643697e-9; -2.5833723463850055e-9; -3.5272374701761183e-9; 1.0748082428071232e-9; -4.402871184029379e-9; 8.439782373357593e-10; 3.468277704680096e-9; 1.911679827040112e-9; -6.690884823769753e-10; 2.4379939254684797e-9; -2.2164215201145976e-10; 1.096750871941204e-10; 2.3184491313623247e-9; 1.3109686353049884e-9; 5.5590264959292115e-9; -1.7664080458355945e-9; 1.3073173138814822e-9; -2.638052987661061e-9; -1.7937228178960856e-9; -2.9095197619190105e-9; 1.6912296741335406e-9; -1.6985117145625505e-9; 1.1466938319198792e-9; 2.8450742469905996e-9;;]), layer_4 = (weight = [-0.0005323928768761613 -0.0007894619577301077 -0.0006543585767488301 -0.0007034991772171813 -0.0007600957375610024 -0.0007851951901347389 -0.0007313171680350887 -0.000661682646145613 -0.0006701723184385237 -0.0006173697381690178 -0.0007825279990756601 -0.0006805750847731156 -0.0007945250188785071 -0.0005216925298062381 -0.0006385653360200036 -0.0006543465991457345 -0.0007187391859466343 -0.0007565943581284801 -0.0006546388038430187 -0.0008505209368499837 -0.0005410664967576965 -0.0007871846440582564 -0.0006873682310262982 -0.0006221848776228572 -0.000721120729433857 -0.0007662559916596304 -0.0006184200494979283 -0.0006850955040645564 -0.0008481186355750817 -0.0005991206515183062 -0.0006789552103689989 -0.0006936197592131213; 
0.00020780268683895123 0.00022301495994640015 0.00034588154650317684 0.00023196114374237102 0.0001266036288086739 0.0002603136277463049 0.00011361084991039054 0.00011226089709543675 0.00019438376241555163 0.0001076105685714001 0.00013582891030860467 0.00020825471051468122 0.00037229344246301793 0.00023929970462951706 0.000497038722457468 0.00024146280792742403 0.00026658138230789823 0.00022261879263402921 0.00018503784396299938 0.0003808090307874895 0.00017932229704676955 0.00032775046406881096 2.5168810238514357e-5 0.00039239134210360116 0.00026701355709693935 0.0001297956058632726 0.0001710948701523308 2.2075604601449307e-5 0.00022185534322006394 0.0003301142402265051 0.0003692428965310408 0.0003636883630643871], bias = [-0.0006844514974276751; 0.00022316301094391861;;]))

Visualizing the Results

Let us now plot the loss over the course of training.

julia
begin
    # Plot the recorded loss history against the iteration index.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    # Makie names these attributes `marker` and `alpha`. `markershape`/`markeralpha` are
    # Plots.jl keywords that Makie does not define: older Makie silently ignored them (so
    # the markers were never translucent), and newer Makie rejects them as invalid.
    scatter!(ax, eachindex(losses), losses; marker=:circle,
        markersize=12, alpha=0.25, strokewidth=2)

    fig
end

Finally, let us visualize the results.

julia
# Re-solve the neural ODE with the optimized parameters `res.u`, using fixed-step RK4
# (`adaptive=false`, step `dt`) and saving at `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
# `compute_waveform` is defined earlier in the tutorial; its first returned value is taken
# as the waveform (presumably a tuple whose first element is the waveform — confirm there).
waveform_nn_trained = first(compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Makie's scatter takes `marker` (not the Plots.jl `markershape`); transparency comes
    # from `alpha`, so the Plots.jl-only `markeralpha` keyword is dropped. Older Makie
    # ignored those keywords and newer Makie rejects them.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(ax, tsteps, waveform; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(ax, tsteps, waveform_nn; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    # Each legend entry combines the line and scatter plot of one series.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
# Report the Julia/system configuration, followed by GPU backend details when available.
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 48
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.756 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.