Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

# Disallow scalar (element-by-element) indexing of CUDA arrays.
# NOTE(review): `CUDA` is assumed to be brought into scope by `LuxCUDA` — confirm.
CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; from a scientific machine learning perspective, however, it is not especially relevant.

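We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, §4.3):

$$r_1 = \frac{m_2}{m_1 + m_2}\, r, \qquad r_2 = -\frac{m_1}{m_1 + m_2}\, r.$$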

julia
"""
    one2two(path, m₁, m₂)

Split the relative two-body trajectory `path` (with `path = r₁ - r₂`) into the
trajectories of the two bodies of masses `m₁` and `m₂`, with the centre of mass at
the origin.

Returns `(r₁, r₂)` where `r₁ = m₂/(m₁+m₂) .* path` and `r₂ = -m₁/(m₁+m₂) .* path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    frac₁ = m₂ / total_mass
    frac₂ = -m₁ / total_mass
    return frac₁ .* path, frac₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in orbital-element form into Cartesian coordinates.

`soln` has one column per time step and one of two row layouts:

  - 2 rows `(χ, ϕ)`: the orbit constants come from `model_params = (p, M, e)`
    (`M` is not used here);
  - 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary per column and `model_params` is ignored.

With `r = p / (1 + e cos χ)`, the returned `2 × size(soln, 2)` matrix holds
`x = r cos ϕ` in row 1 and `y = r sin ϕ` in row 2.

Throws an `AssertionError` if `soln` has neither 2 nor 4 rows, or if it has 2 rows and
`model_params` is not a 3-element collection.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # `nothing` has no `length`; check it explicitly so callers get the assertion
        # message instead of a MethodError.
        @assert(!isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of samples `v` taken every `dt`.

Interior points use second-order central differences; the first and last points use
second-order one-sided stencils. The result has the same length as `v`, and at least
3 samples are required.
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)
    # forward stencil (-3/2, 2, -1/2) at the left edge
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # central stencil (-1/2, 0, 1/2) in the interior
    interior = (v[3:n] .- v[1:(n - 2)]) / 2
    # backward stencil (1/2, -2, 3/2) at the right edge
    right = 3 / 2 * v[n] - 2 * v[n - 1] + 1 / 2 * v[n - 2]
    return vcat(left, interior, right) / dt
end
d_dt (generic function with 1 method)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of samples `v` taken every `dt`.

Interior points use the three-point central stencil `(1, -2, 1)`. The two end points
use the four-point one-sided stencil `(2, -5, 4, -1)`. The result has the same length
as `v`, and at least 4 samples are required.
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    body = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    tail = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    return vcat(head, body, tail) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit, together with helpers that turn it into the (2,2)-mode strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return `mass` times one component of the trace-free moment tensor of a planar orbit.

Row 1 of `orbit` holds `x` and row 2 holds `y`. With `trace = x² + y²`:

  - `component == (1, 1)` gives `x² - trace/3`;
  - `component == (2, 2)` gives `y² - trace/3`;
  - any other component gives the off-diagonal term `x ⋅ y`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    trace = x .^ 2 .+ y .^ 2

    tensor = if component[1] == 1 && component[2] == 1
        x .^ 2 .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        y .^ 2 .- trace ./ 3
    else
        x .* y
    end

    return mass .* tensor
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the finite-difference second time derivative (sample spacing `dt`) of the
requested trace-free moment-tensor component of `orbit` (see `orbit2tensor`).
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(1,1)`, `(1,2)` and `(2,2)` quadrupole terms of `orbit`, in that order,
each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the (2,2)-mode strain polarizations `(h₊, hₓ)` of a single unit-mass orbit.

The quadrupole terms from `h_22_quadrupole` are combined as `h₊ ∝ h11 - h22` and
`hₓ ∝ 2 h12`. Both are scaled by `√(π/5)`, and the cross polarization is returned with
a negative sign. The constant is built in the floating-point type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the `(1,1)`, `(1,2)` and `(2,2)` quadrupole terms of two bodies. Each body's terms
come from `h_22_quadrupole` with its own orbit and mass. Returns them in that order.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the (2,2)-mode strain polarizations `(h₊, hₓ)` of two bodies (BH 1 with
`orbit1`/`mass1`, BH 2 with `orbit2`/`mass2`).

The summed quadrupole terms from `h_22_quadrupole_two_body` are combined as
`h₊ ∝ h11 - h22` and `hₓ ∝ 2 h12`. Both are scaled by `√(π/5)` (built in the type `T`
of `dt`), and the cross polarization is returned with a negative sign.

Throws an `AssertionError` if `mass1 + mass2` differs from 1 by `1e-12` or more.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` from an ODE solution.

`soln` and `model_params` are converted to a Cartesian orbit by `soln2orbit`. `dt` is
the spacing of the solution samples, used for the finite-difference derivatives.
There are two cases:

  - `mass_ratio == 0` (test particle) uses the one-body strain formula;
  - `0 < mass_ratio ≤ 1` splits the relative orbit into two bodies with
    `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`.

Throws an `AssertionError` unless `0 ≤ mass_ratio ≤ 1`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE. The state is `u = [χ, ϕ]` and `p`, `M`,
`e` are constants. Returns `[χ̇, ϕ̇]`; `t` is not used.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    numerator = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = numerator * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root)
    ϕ̇ = numerator / (M * (p^(3 / 2)) * root)

    return [χ̇, ϕ̇]
end

mass_ratio = 0.0         # test particle: compute_waveform uses the one-body strain formula
u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
datasize = 250           # number of waveform samples
# NOTE(review): Float32 literals, so `tsteps` and `dt_data` are Float32 while `u0`/`dt`
# are Float64 — confirm this precision mix is intended.
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # sample spacing, used by the finite-difference stencils
dt = 100.0               # fixed RK4 step size (solves use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model using OrdinaryDiffEq.jl and plot the results.

julia
# Integrate the relativistic (ground-truth) orbit with fixed-step RK4, saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first returned polarization (scaled h₊) as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line plus sample markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # One legend entry covering both the line and the markers.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input ($\chi$, the first component of the state) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally not recommended, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use `truncated_normal` from WeightInitializers.jl:

julia
# ODE_model feeds [χ] into this network: cos.(input) → Dense(1→32, cos) → Dense(32→32, cos)
# → Dense(32→2). All weights are drawn from a truncated normal with std 1e-4, presumably
# so the untrained output is ≈ 0 and the model starts close to the plain Newtonian one.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# NOTE(review): `Xoshiro()` is unseeded, so the initial parameters differ between runs;
# pass a seed (e.g. `Xoshiro(0)`) if reproducibility matters.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00019958167; -8.23322f-5; 0.00011193903; 3.972371f-5; 1.2679992f-5; 7.962795f-5; -9.4820105f-5; -1.0310191f-5; -3.0485766f-5; 6.003493f-6; -7.885363f-5; 4.102652f-5; -9.968832f-5; -8.279307f-5; -1.4096135f-5; -7.4919546f-5; 9.036909f-5; -2.222121f-5; 0.00024888132; -1.42023f-5; 0.00017753802; -1.0629531f-5; 2.4538875f-5; 1.4521904f-5; -7.729492f-5; -1.4483704f-5; -0.00011247258; -0.00010900453; 2.9040108f-5; 4.780892f-6; 0.00018319089; 2.0366484f-5;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[-8.2967206f-5 6.126505f-5 -5.4202836f-5 -6.0297166f-6 4.444496f-5 -0.00010085247 0.00011752079 1.0015898f-5 -4.763423f-5 3.996415f-5 -6.184057f-6 -9.5429874f-5 0.00017989287 2.559291f-5 8.8970075f-5 -5.0847302f-5 -5.309836f-5 4.0173643f-5 0.00010814917 4.8285518f-5 8.153495f-5 -0.00024206469 -2.9612946f-5 -6.846594f-5 -6.516001f-5 -0.00019479581 -2.9676923f-5 -0.00012303567 0.00016065197 -0.00020494669 6.5175955f-5 -0.00011665386; 4.4113836f-5 0.00020232865 4.031389f-5 0.00017390863 -0.00018298297 -2.8041444f-5 0.00022360511 6.0158767f-5 -0.00010068454 -0.00014094544 8.87567f-6 9.5228956f-5 -9.564057f-5 -6.855729f-7 -0.0002483263 0.00010052261 7.148451f-5 -0.00015363927 -0.00011880494 -7.133596f-5 1.5844196f-5 -0.000109300134 -5.648238f-5 5.984945f-5 9.651995f-5 -0.00014640989 -0.00033336596 -1.2647464f-6 1.6209722f-5 9.711934f-6 0.00011689553 -7.268839f-5; -9.453442f-6 2.0036869f-5 0.00011599127 -0.000106741 5.674897f-5 3.1789798f-5 8.190024f-5 -9.972924f-5 5.6343517f-5 -3.495128f-5 -4.1806707f-5 2.0252124f-5 5.9166774f-5 -4.104539f-5 0.0001257722 -0.00014815027 4.2301737f-5 9.630039f-5 -1.2042221f-5 -7.472567f-5 -0.0001228687 6.46884f-5 0.00011598006 7.659546f-5 4.386737f-5 -0.00023943685 1.3806363f-5 7.543245f-5 0.00015396495 -3.9893523f-5 2.9165818f-5 
0.00012621297; -0.00018614276 -4.2515483f-5 -0.00012903429 7.3114694f-5 0.00011710013 -1.4751085f-5 -5.653764f-5 -5.1581996f-5 -2.3857521f-5 -1.087404f-5 -7.393113f-5 -8.453979f-7 0.00014439343 0.00017491054 8.902615f-5 5.7722445f-5 8.567603f-5 -2.6166338f-6 0.00015045679 -9.028511f-6 -0.00018741612 -2.6028476f-5 -5.7995145f-5 0.00014572714 -0.00010921581 -0.00012703521 1.11311665f-5 5.4957316f-5 0.00014796825 -7.540946f-5 4.4074848f-5 3.7858572f-5; 0.00019415154 -0.00017560346 7.807468f-5 0.00013261064 7.159899f-5 9.862555f-5 0.00016940045 -4.3855707f-5 6.181105f-5 0.00013408961 7.479411f-5 -1.11049785f-5 -1.630245f-5 7.5183234f-5 0.00016757542 -0.00016842238 3.536953f-5 -0.00013181406 -0.000108641485 -7.675599f-5 1.7185474f-5 0.00018902037 0.00016615262 5.3984422f-5 -5.7781086f-5 4.0363324f-5 6.879808f-5 -0.00021970726 -2.6365584f-5 0.00014945403 -4.456291f-5 0.00010107857; -7.9340134f-5 0.000101121805 1.426412f-5 0.00013615334 -1.49850275f-5 -0.00022095289 -8.6595006f-5 0.00017674715 -7.695296f-5 1.2352462f-5 0.00013914914 0.00012844999 -2.4774909f-6 2.3164013f-5 3.0513422f-5 -0.000103303406 3.224719f-5 0.00010993326 7.576403f-5 -0.0001009203 -0.00014358954 -1.1956023f-5 3.0710973f-5 9.648503f-5 -9.615514f-5 1.9898369f-5 4.9608385f-5 5.1221185f-5 6.211366f-5 -0.00011905759 0.00025028456 -1.3847136f-5; 3.096025f-5 4.2444666f-5 -0.00012679708 -0.000121237565 -1.2929447f-5 4.9318656f-5 4.9963517f-5 3.820712f-5 2.046517f-5 3.3098266f-5 -0.00012352118 0.000109712724 -6.173482f-5 2.4895833f-5 -0.00014110643 -0.00010128689 0.00013700998 -1.7972088f-5 -0.00013546772 -4.238959f-5 -7.421566f-5 5.7370007f-6 -5.4116732f-5 -9.927378f-5 -0.00012151985 7.87254f-5 -9.811379f-5 -1.0532768f-6 0.00012664589 0.00011513158 -1.9218582f-5 0.00016230515; -0.000122027945 6.117091f-5 -0.00012924873 -6.962836f-5 0.00011433479 -1.6616108f-5 0.00019748052 8.237187f-5 -0.00018598307 -0.0001585303 -7.7764846f-5 8.050949f-5 3.1773077f-5 2.974107f-5 7.182804f-5 -0.00018487914 -9.9738545f-6 
-1.6996277f-6 8.579841f-5 -7.9025194f-5 -2.6680449f-5 0.00014714779 0.000101376514 -6.7830544f-5 -7.017695f-5 0.00011276063 -2.0974665f-5 -5.7865658f-5 6.860928f-5 0.0001316426 6.941627f-6 -1.7163366f-5; 5.3994034f-5 -5.225011f-5 -0.0002602632 6.164099f-5 8.924189f-5 -1.7354989f-5 0.00026274362 -9.169149f-5 -3.9186347f-5 -4.325795f-5 -3.604186f-5 7.7113305f-5 -0.00017112572 3.9365004f-5 -0.00013173003 -0.0001536866 0.000111204004 1.2391155f-5 8.533798f-5 7.05171f-5 8.091578f-5 7.371852f-5 0.00012445357 8.3389714f-5 -0.00011529438 6.005556f-5 2.743381f-5 8.946204f-5 -0.00011718202 4.031101f-5 1.8349814f-6 6.513537f-5; -9.6204494f-5 -0.00012052375 4.9115304f-5 0.00010186111 -7.7331184f-5 -9.612737f-5 -1.21397015f-5 -0.00017034102 0.00021902032 -4.1375435f-5 -2.9474433f-5 6.866365f-5 5.655771f-6 5.149859f-5 0.00011851082 -4.8823924f-5 -0.00015756459 -0.00015023575 -0.00015709361 -0.00019265419 8.665964f-6 3.301074f-5 6.2747495f-5 4.4142193f-6 -8.882171f-5 5.9867663f-5 9.565558f-5 -0.00012322878 0.0001396682 -1.0528436f-6 1.3017588f-5 -6.830124f-5; 0.00014043535 0.0001458661 -1.0580669f-5 3.415002f-5 -9.094207f-5 -0.000104109815 -0.00017184741 -8.924164f-6 1.1885015f-5 -8.385902f-5 -0.000133324 0.00011912828 -1.1354923f-5 8.160344f-5 0.00013524597 9.634392f-5 -5.7598136f-5 9.013959f-6 0.000107318105 5.336965f-5 -6.3373074f-5 -5.8402744f-5 -6.96892f-5 -9.5489326f-5 2.8842527f-5 -0.00014465171 -2.922795f-5 8.57541f-5 0.00010778785 -0.00012160734 0.00017163868 -0.00022549221; -6.989397f-5 0.00017408725 0.00013634625 -0.000108445856 6.092942f-6 -0.00014905736 -0.000111021196 5.3433712f-5 8.212654f-5 -5.377925f-5 3.6583475f-5 0.00014877653 -0.000103342165 0.00021884569 4.7007223f-5 1.8562456f-6 5.6556608f-5 0.00011660339 0.00011577796 0.00017899163 7.059064f-5 -7.932753f-5 -3.6877773f-5 -0.00016800762 -7.462748f-6 0.00015160255 -1.3513586f-5 7.1694696f-5 -4.3576347f-5 2.5499305f-5 -9.5107f-5 7.2401504f-5; -0.00011767648 3.1720643f-5 0.00023254725 -6.900022f-6 8.976212f-5 
9.7895f-6 -7.5759235f-5 4.072951f-6 -4.1957966f-5 3.6865556f-5 3.399748f-5 1.9145322f-5 1.5643536f-5 7.8599085f-5 8.984544f-5 7.834302f-6 0.0001485694 2.6982007f-5 0.00012535158 -0.00014617997 5.068204f-5 2.1158374f-5 3.4197536f-5 7.327326f-5 5.7836125f-5 -5.782797f-5 -3.750994f-5 -6.325415f-5 -1.960383f-5 3.468978f-5 3.6919067f-5 4.111529f-5; -4.8490485f-5 7.770821f-6 -4.4839417f-6 -4.8364036f-5 6.733177f-5 1.5546504f-5 -1.21245f-5 0.00011424743 0.00012973076 0.00010530255 -1.7258866f-5 0.00013159945 4.2534568f-5 1.3018493f-5 0.000373051 -6.301845f-5 -2.6281828f-5 4.78523f-6 -3.2647997f-5 2.0550951f-5 0.00012586171 0.000108761575 1.734099f-5 1.5186557f-5 4.101194f-5 0.00020249262 3.976822f-5 -5.617266f-5 3.3352684f-5 6.8494424f-5 -3.5432186f-5 -4.050125f-5; 9.763546f-5 -5.7528345f-5 5.1951174f-5 2.808614f-5 -0.00014810613 4.5977802f-5 -8.949658f-5 0.00013744613 -8.155766f-5 0.0001325531 0.00010143081 7.595941f-5 5.382325f-5 8.803268f-6 -6.424135f-5 1.2450341f-5 0.00018162804 -0.00011887154 -7.6377255f-5 -0.00014632172 4.791566f-5 6.452713f-5 0.00013130263 -5.270514f-5 -0.00010863828 -1.0926704f-5 9.174357f-5 5.130808f-5 -0.0002238504 0.00012229053 8.130676f-5 2.482176f-5; -0.00022460023 -8.5927626f-5 0.00012520017 0.000106854335 5.516082f-5 6.469363f-5 -0.00017428715 0.000106712556 5.7775986f-5 0.00016063757 -0.00011905362 1.7139606f-5 7.231819f-5 3.261396f-5 4.553367f-6 -1.2987151f-5 -1.5418596f-5 6.43568f-5 -0.00010344612 -4.3443855f-5 -6.226517f-5 -0.00032615496 0.00011252029 9.933419f-6 -5.438312f-5 -0.00013958594 7.4375333f-7 0.00015180545 -1.3335645f-5 0.00011169598 1.22708825f-5 -6.303879f-5; -0.00024773317 0.00010009082 -0.0001319392 -1.1457318f-5 -1.3049577f-5 -0.00010246004 -0.00022768776 -7.770106f-5 0.0002451597 0.00014198746 0.00011474301 5.583269f-5 2.0425464f-6 -0.00012983207 -5.096399f-5 -0.000100110745 -9.638766f-5 5.8334062f-5 -0.00014699885 -0.00013523966 -6.4942076f-5 9.148771f-6 -0.000100656165 8.327f-5 -3.207574f-5 0.00020708723 3.5749294f-5 
0.00016056931 -0.00010866737 -0.00010157468 0.0001012336 1.7625822f-6; -9.068424f-5 0.00014179164 -0.00024294283 -7.732556f-5 8.907411f-5 -2.9104913f-5 -2.5391788f-5 -3.5044155f-5 -1.5408028f-5 2.0861247f-5 6.2595886f-5 1.5647702f-6 -1.4038736f-5 4.9628932f-5 0.00012918325 7.1496856f-5 -7.907139f-6 0.00010819498 7.363849f-5 5.060041f-5 -2.4479506f-5 -2.7930431f-5 -7.071934f-5 -0.00016543592 4.2123054f-5 -0.00011032092 -7.761623f-5 0.0001958077 1.6854741f-5 9.59017f-6 7.129312f-6 3.3057426f-5; 3.6966194f-5 5.6543813f-5 -2.207663f-6 -1.0656261f-5 1.4857007f-7 4.354995f-5 -0.00019861762 6.736715f-6 -8.8661494f-5 -1.9353138f-5 9.3636285f-5 -1.8943569f-5 1.3930739f-5 -0.00011097215 7.488149f-5 2.949035f-5 2.1540607f-5 -0.00012140342 -0.00014297049 -9.4322786f-5 4.294246f-5 2.2478362f-5 0.00010015502 8.0243284f-5 0.000101689846 -7.1336886f-5 -1.7529f-5 -2.449872f-5 0.00014522893 0.00014897573 -0.00015612852 1.28206775f-5; 1.4398359f-5 3.2457443f-5 -0.00029387916 -4.6320798f-5 -8.863483f-5 0.00016047833 0.00011034151 -3.815169f-5 -6.885017f-5 -4.075936f-5 -9.991774f-5 9.801379f-5 5.7009653f-5 -1.5231554f-6 -1.8516292f-5 -2.8484605f-5 0.00010737573 -3.947975f-5 -5.618711f-5 6.973052f-5 -8.047419f-5 -1.0998716f-5 -6.4371074f-5 0.00010291745 -7.3536365f-5 -3.1196283f-5 6.3143285f-5 5.5706387f-5 0.00018013634 2.792908f-5 -0.00013388673 -6.157502f-6; -0.0001964347 -0.00011927899 -9.8079654f-5 -3.8966034f-5 -2.9552055f-5 5.5855617f-5 -4.694224f-5 -3.8038594f-5 -0.00015996881 -1.4140501f-5 -0.00012714698 1.5393514f-5 -9.320554f-5 -2.2272056f-5 0.00024151863 0.00016515286 0.00012853983 4.668511f-5 -3.6681784f-5 -0.00017783654 -0.00013570293 -5.7195328f-5 9.173128f-5 4.3662254f-5 -4.9679547f-5 4.0555762f-5 -0.00021760406 8.167473f-5 2.2623724f-5 -2.493068f-5 -4.9671694f-6 0.00018992537; 5.737775f-5 -7.549049f-6 -0.00019681975 -6.78784f-5 -7.41357f-5 -0.00016236564 5.0159626f-5 0.00012603565 -6.342862f-5 -0.00014600584 -5.182984f-5 3.124025f-5 8.527024f-5 0.00012836856 
-1.2367533f-5 -7.4468967f-6 3.0178877f-5 6.354632f-5 0.00019827831 0.00011174311 9.755978f-5 0.00013481562 7.9268775f-6 -0.0001198712 6.693998f-5 0.000112349175 -6.705077f-5 -9.249447f-5 0.00021250913 -1.2647807f-5 -1.4687137f-5 -0.0001278586; -1.885313f-5 2.4443145f-5 -6.5784996f-5 1.3019471f-5 0.00020979207 -0.00018175204 -3.3247616f-5 0.0001514494 7.141005f-5 9.0260924f-5 -2.8416369f-5 6.3499545f-5 7.4311656f-6 3.864039f-5 5.6633467f-5 -0.00019341717 -3.2725493f-5 -7.4408505f-5 2.442655f-5 4.20549f-5 7.735792f-5 -0.00013161304 -7.858171f-5 -7.521495f-5 -6.103639f-5 -9.230146f-5 -2.6068308f-5 -1.3642723f-5 -3.500341f-5 2.7376756f-5 5.5586715f-5 -0.0002424921; -7.726509f-6 -0.00010099817 -3.20016f-5 0.00016212618 -1.46627f-6 0.00011580428 1.20140185f-5 -9.7648284f-5 0.00017910384 2.2191425f-5 1.5116349f-6 -1.4704339f-5 -6.92027f-6 -6.8129975f-6 3.934899f-5 0.0001409655 -4.5696193f-5 -0.00024403422 -3.646131f-5 1.4705124f-5 8.009015f-5 -3.7039677f-5 -0.00018389421 -0.00010777318 -0.0001859717 4.1653162f-5 -2.7844068f-5 -7.197416f-5 -0.00010146512 -0.00022400953 -7.1817856f-5 6.236185f-5; 1.4897369f-5 4.231623f-5 -0.00013381435 -3.970811f-5 6.817642f-5 -5.2900617f-5 0.00016359973 -0.0001952099 -9.406169f-6 -4.954719f-5 3.8225167f-5 -4.3164422f-5 -1.2963211f-5 -6.235247f-5 -0.00010031445 4.708623f-5 -3.610987f-5 2.4581605f-5 5.6638924f-5 8.433728f-5 0.0002550415 -6.214299f-5 -7.6258857f-6 3.2787997f-5 -7.7069446f-5 -0.000100369594 -0.00010209731 -7.4618f-5 -0.00013067485 -1.9353625f-5 -8.111027f-5 8.24392f-5; 8.960741f-5 0.00019724156 0.00015462135 2.9635972f-5 0.0001593961 3.7151967f-5 6.3679354f-6 -8.666865f-5 -0.0001068992 -0.00019187853 0.00015730977 1.7624925f-6 0.00014806377 6.354459f-7 8.3125255f-5 0.00013924995 1.8685347f-5 0.00012145554 4.4409953f-5 0.00016537831 0.00014839771 -0.00010969996 -0.0001184646 -4.1579504f-5 1.4801489f-5 0.0001292968 4.3741247f-6 -1.2043425f-5 -0.00012699775 -8.224768f-5 -3.6266083f-5 2.5370216f-5; 0.0001293436 -8.392708f-5 
1.309504f-5 -2.457392f-5 7.811217f-5 -3.972289f-6 0.00021810905 -2.0620373f-5 -6.270751f-5 3.3871598f-5 6.8864238f-6 -9.0639305f-6 9.530039f-5 0.00015639012 0.00011057578 -6.7414767f-6 0.00014669566 1.1569562f-5 -4.1980762f-5 -0.000113010086 2.7441303f-5 -0.0002836302 0.00026659408 -3.419725f-5 -1.3990363f-5 6.84571f-5 -8.532784f-6 -1.8810777f-5 -0.00010572384 -0.00028054402 -3.2482985f-5 -2.324051f-6; 3.991473f-6 -1.5515398f-5 -0.00015442113 -8.03961f-5 0.0001120888 -1.38606665f-5 -4.5125747f-5 1.3196238f-5 4.1471692f-5 -0.000102413935 -0.00016438442 -8.499651f-5 4.336627f-5 -6.317252f-5 -5.999334f-5 0.00017749694 -1.3987482f-5 -8.4993655f-5 -8.641106f-5 -0.00010763989 0.00025578227 6.9069865f-5 -0.0001065085 7.336336f-5 0.00011119968 6.040924f-5 8.8435176f-5 9.1566195f-5 -3.660159f-5 -6.77505f-6 -1.1555115f-7 -6.3658015f-5; 3.1354543f-6 2.9152177f-5 -0.00014291164 7.4214426f-5 8.710718f-5 -8.4190506f-5 -4.4497825f-5 0.0001168902 5.6892488f-5 0.00015479216 -0.00014288144 -0.00014104888 5.3111642f-5 -0.00010691237 -5.9395676f-5 -0.00031708306 -0.00012935478 -3.540072f-5 -6.089845f-5 -1.1351615f-5 0.0001030787 2.2343387f-5 1.9484452f-5 -4.4928605f-5 1.3636796f-5 -6.3579496f-6 -0.00012646658 3.919319f-5 -6.258364f-5 -8.712005f-6 0.00013614254 -0.000107833235; 6.229929f-5 -2.2271228f-5 -7.58702f-5 3.6599882f-5 -8.671693f-5 -2.5245861f-5 0.000113185204 -0.00013966365 4.668133f-5 -0.00012114791 -5.44134f-5 1.6244934f-5 -0.00014392158 -1.6995558f-5 8.191975f-5 -8.5246495f-5 0.00024936238 -2.92088f-5 -8.638919f-5 0.00018667645 7.906948f-5 -9.467499f-5 -3.2949323f-5 0.00016831954 -3.473756f-5 2.538361f-5 -0.00011548636 4.25295f-5 -6.0093225f-7 0.00014725166 -8.5497944f-5 -8.32532f-5; -0.00016864079 -2.873618f-5 7.649325f-5 -3.3309723f-5 0.00011391369 0.000109745 2.68027f-5 -4.9279395f-5 9.769898f-5 3.0485953f-5 -0.000107094456 0.00014613089 7.7494886f-5 -0.00022424967 -7.1299044f-5 -0.000115787014 -0.00014264238 1.1390487f-5 4.6128775f-5 -8.826007f-6 -0.00035717175 
9.965132f-5 -0.0002540176 -4.4327557f-5 9.308038f-8 1.6615031f-5 -7.244013f-5 2.6939926f-5 0.000121039615 -7.944743f-5 1.5331561f-5 -0.00013789703; -8.519436f-5 -1.6422325f-6 0.00012381554 -9.066643f-5 0.00010727126 9.613937f-5 7.591667f-5 0.00014215847 2.1942931f-5 -3.1513086f-5 -0.000104561186 3.055686f-5 -0.00011274814 -0.00013097693 -6.373863f-5 -7.441696f-5 1.6399146f-5 -0.00015531115 -0.00020659683 -4.359769f-5 -1.5500504f-5 -1.550451f-5 0.00014970465 1.9846284f-5 0.00011537885 -2.097749f-6 0.0001436089 2.3048236f-5 2.0166059f-5 2.3955945f-5 -1.4598792f-5 5.758373f-5], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[-2.444522f-6 0.000112879185 0.0001558351 9.196768f-6 -8.893825f-5 -8.4488376f-5 2.2426927f-6 -5.6865458f-5 -1.8749115f-5 9.377462f-5 2.4503804f-6 8.707287f-5 -0.000106156476 -9.007546f-5 -0.00010368377 -0.00012888272 -8.414403f-6 -6.5748565f-5 0.00011486057 6.840657f-5 -7.89175f-5 -8.465059f-5 -2.0313388f-5 1.1040699f-5 -1.652368f-5 -0.00023326266 1.8083672f-5 2.1020029f-5 -0.000107340056 6.970476f-7 5.845529f-5 6.790722f-5; 0.0001303153 -5.0481543f-5 -0.00017039261 8.006863f-5 8.246246f-5 2.0142534f-5 -9.491559f-5 -5.407466f-6 3.3323842f-5 3.4727516f-5 -7.562626f-5 0.00014555441 -3.4489447f-5 -1.9135081f-5 4.9328537f-6 -8.310666f-5 3.124505f-5 6.8021596f-5 0.000120236116 0.00014720758 -2.8456117f-5 -2.9153392f-5 -0.00010461339 -1.3453715f-5 -3.6473368f-5 0.00014593621 3.1647658f-5 8.535252f-5 -8.636038f-5 -4.110334f-5 8.381028f-6 -7.8705365f-5], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most DL frameworks, Lux defaults to Float32; in this case, however, we need Float64.

julia
# Convert the Float32 parameters from Lux.setup into a Float64 ComponentArray; this is
# what the ODE solves and the optimizer operate on.
const params = ComponentArray{Float64}(ps)

# Bundle the network with its state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(broadcast), typeof(cos)}(broadcast, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, with each rate scaled by a neural-network correction. It uses

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE. The state is `u = [χ, ϕ]`, and the return value is
`[χ̇, ϕ̇]`.

Both rates share the Newtonian factor `(1 + e cos χ)² / (M p^(3/2))`. That factor is
multiplied component-wise by `1 .+ nn_model([χ], nn_params)`. `p`, `M`, `e` are read
from the global constant `ode_model_params`; `t` is not used.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example `st` is an empty NamedTuple, so we can safely ignore it; in general,
    # `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters on the same grid as the data.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained model's waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # One legend entry per series, each covering its line and markers.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
"""
    loss(θ)

Sum of squared differences between the target `waveform` and the waveform produced by
the neural ODE with parameters `θ`. Returns `(loss value, predicted waveform)`.

Reads the globals `prob_nn`, `u0`, `tsteps`, `dt`, `dt_data`, `mass_ratio`,
`ode_model_params` and `waveform`.
NOTE(review): most of these globals are not `const`, which makes this function
type-unstable; consider declaring them `const`.
"""
function loss(θ)
    trajectory = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    h_pred = first(compute_waveform(dt_data, trajectory, mass_ratio, ode_model_params))
    residual = waveform .- h_pred
    return sum(abs2, residual), h_pred
end
loss (generic function with 1 method)

Warm up the loss function:

julia
# Call the loss once before training (warm-up); it returns (loss value, predicted waveform).
loss(params)
(0.1755850988956685, [-0.02426111910576223, -0.02347645122029886, -0.022691783334835243, -0.021365969389331656, -0.019471305086901584, -0.016967578252165612, -0.013801998169138888, -0.009906239162539791, -0.005198714422377064, 0.0004157639798289335, 0.007036769176107432, 0.01474730556026759, 0.023566200118488166, 0.0333197669452577, 0.04336370528619235, 0.051925694639663164, 0.0547149674968884, 0.042653942218975074, 0.0021999808506005555, -0.06607287145320692, -0.11035535227510505, -0.07652159436878266, -0.006929494493876874, 0.03887234928587742, 0.05436526526391928, 0.053009142768105075, 0.04485563944382503, 0.034798438490344176, 0.02488248804529451, 0.01586751458130129, 0.007967068810792646, 0.001178879637523359, -0.0045786903875267865, -0.009407336262382483, -0.013405692480112766, -0.016658207147683835, -0.01923663382080456, -0.021197027884398537, -0.022581951935452562, -0.02342139344765542, -0.02373375966543449, -0.023526245934351604, -0.022796130580958487, -0.021528449986342478, -0.019698142210387126, -0.01726771552845578, -0.014186519660477957, -0.010390214179734047, -0.005800030012946644, -0.00032435256920097485, 0.006133764319610618, 0.013660688800939028, 0.02228403980851572, 0.031869458694453894, 0.041868963468044314, 0.05074977776383226, 0.054755316854178296, 0.04574321977051289, 0.010490945772289633, -0.0551447293426093, -0.10818452403906309, -0.08613318379001213, -0.01683090195396896, 0.034334959295914146, 0.05365651287579399, 0.053977044227627766, 0.04633713711917263, 0.036302926373076104, 0.026233593365580236, 0.01702171595941834, 0.008925490365459915, 0.001967218881961864, -0.0039408858302297945, -0.008892249783583567, -0.012997768789891551, -0.016337944766973325, -0.018994757519351525, -0.02102087500787412, -0.022466817282144193, -0.023360961022414094, -0.02372641793552793, -0.023571504140187657, -0.02289447548559312, -0.021684429945856773, -0.01991822584843278, -0.017557662111751453, -0.014560713540167528, -0.010858860548438678, 
-0.006384564113909405, -0.0010414150085042647, 0.005257247501767055, 0.012606313292655281, 0.02103553563361087, 0.030449483135275825, 0.040377437224583916, 0.04950077538147726, 0.05453326188754803, 0.04820824256359265, 0.01790701340221654, -0.044065657059072436, -0.10393761502987695, -0.09456407935529938, -0.027373122711411648, 0.028979626318481066, 0.05253159177874719, 0.05480345117574776, 0.047798162456091815, 0.03783011974548098, 0.02762188062448431, 0.01820747933366149, 0.009917544385011226, 0.0027765210769050996, -0.0032792541309421375, -0.008364433610233353, -0.012574236849541818, -0.01601019290459573, -0.01874237593296127, -0.02083967684588146, -0.022344597117855212, -0.023296045063698975, -0.023713984079343827, -0.023611168194971707, -0.022988484925680102, -0.021834146235097194, -0.020130271716180904, -0.01783959674089654, -0.014922940672593742, -0.011314714859625704, -0.006951034907422195, -0.0017379828626560732, 0.0044076563957848045, 0.011581813577204928, 0.01982206095204888, 0.029058815208794483, 0.038896303383762494, 0.04819551082419793, 0.05408656720944116, 0.050119437730996114, 0.02444105871175129, -0.033115097506984, -0.0978152640033257, -0.10147493777886196, -0.03838453879517387, 0.022754784844850922, 0.05093209760962993, 0.055456752873067144, 0.049227632118600324, 0.0393791983127865, 0.029040719486885306, 0.019433774899220924, 0.010936835445869122, 0.003615486267238955, -0.002602900735798512, -0.007814292793050964, -0.012140628444281385, -0.015669377526998398, -0.01848517048729763, -0.020648244400356428, -0.02221790318203472, -0.023225299378111415, -0.02369676319786393, -0.023646269573925, -0.02307641450368083, -0.021978825337411027, -0.020333138903045136, -0.018115218257488803, -0.015273648577038267, -0.011756380913547899, -0.007501855386595379, -0.002415412292607207, 0.003584387376186222, 0.010588489891749995, 0.018640056932756076, 0.02769945507334231, 0.03742950730564378, 0.04684757763801578, 0.053454511892943604, 0.05153631861538116, 
0.03012486293319304, -0.022547344181311817, -0.09007179416277088, -0.10658138265586997, -0.0496414806954263, 0.01563428243864193, 0.04878501151362709, 0.05590551961757365, 0.05061181562930807, 0.04094203401527044, 0.030500413954193565, 0.02069109801156055, 0.011989779574920461, 0.004477351471843721, -0.001896677673821051, -0.007252569426478172, -0.011692316930253067, -0.015319916010115673, -0.018213774480559858, -0.020453233923777323, -0.022084506274614432, -0.0231494981678909, -0.023674102412438343, -0.023676471268542287, -0.02315922888568895, -0.022115704289502422, -0.02053381408528086, -0.018378881888003705, -0.015615039084470829, -0.012186301566930088, -0.00803575207639365, -0.003071913721700919, 0.002784477355332151, 0.009622618592788389, 0.01749240804955166, 0.026371888901212313, 0.035979821020977736, 0.04547019135810037, 0.05266552375127629, 0.05252185503244062, 0.03499485123632218, -0.012540213718930802, -0.08103382178663386, -0.10966793757921296, -0.060854994926291345, 0.007597698091568779, 0.04602934979974219, 0.05610660642777628, 0.05193478471942725, 0.04251840146036781, 0.0319899870601611, 0.021987491849130524, 0.0130755674181075, 0.005369298626616052, -0.0011724829923135298, -0.006670787881599578, -0.01122955929003314, -0.014957606443781732, -0.01793644499993944, -0.020249185757831325, -0.021944899344254264, -0.023068068710009753, -0.02364679097900573, -0.02370156757243481, -0.023236873742492385, -0.02224814251928889, -0.020725159028273342, -0.01863584427956667, -0.01594645506690939, -0.01260396339968507, -0.008554396354644845, -0.004504829309604322])

Now let us define a callback function to store the loss over time.

julia
# Loss value recorded at every optimizer iteration (appended to by `callback`).
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Optimizer callback. It appends the current loss `l` to `losses` and prints a progress
line with the iteration count. It always returns `false`; Optimization.jl stops early
only when a callback returns `true`. `θ` and `pred_waveform` are not used.
"""
function callback(θ, l, pred_waveform)
    iteration = length(push!(losses, l))
    @printf("Training %10s Iteration: %5d %10s Loss: %.10f\n", "", iteration, "", l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It works well here because the Newtonian model already provides a very good initial guess for the parameters.

julia
# Gradients of the objective are obtained with Zygote (reverse-mode AD).
adtype = Optimization.AutoZygote()
# Optimization.jl calls the objective as f(θ, hyperparameters); only θ is used.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step norm (0.01) and a backtracking line search.
# `callback` logs each iteration; at most 1000 iterations are run.
res = Optimization.solve(optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.0001995816710402317; -8.233219705286294e-5; 0.0001119390290112088; 3.97237090510499e-5; 1.2679992323675025e-5; 7.962794916238225e-5; -9.48201050050406e-5; -1.031019110086252e-5; -3.0485765819346404e-5; 6.00349312662863e-6; -7.885362720104906e-5; 4.10265201935305e-5; -9.96883172774473e-5; -8.279307075994141e-5; -1.4096134691468018e-5; -7.49195460230117e-5; 9.036908886614352e-5; -2.2221209292156393e-5; 0.00024888131883899574; -1.4202300008027273e-5; 0.00017753802239876753; -1.0629531061555524e-5; 2.4538874640669865e-5; 1.4521903722193871e-5; -7.729492062930713e-5; -1.4483704035211457e-5; -0.00011247258225904423; -0.00010900453344226015; 2.9040107619963955e-5; 4.7808921408378555e-6; 0.00018319088849247926; 2.036648402279294e-5;;], bias = [-3.0482952942768077e-16; -3.2474755440106295e-17; 2.2532688448083705e-16; -2.035566704445742e-17; 8.972062109852735e-18; 5.359355281656958e-18; -1.759642642754663e-17; -9.096507853726758e-18; 2.794450552922776e-17; -1.5237739096394932e-18; -1.354510671556997e-16; 2.6978982141489522e-17; -2.544585556686923e-17; -1.651198344629086e-16; -3.500471390401383e-17; -7.766297441162514e-20; 1.5484186434185959e-16; -3.0314659810809664e-17; 3.0276347675219973e-16; -1.0802658900779026e-17; 2.0173674378212204e-16; -1.3889027685295928e-17; 4.653778188282393e-17; 4.5299856690810136e-18; -5.503130590855869e-17; -1.911019752690796e-17; -1.9259655004793251e-16; 6.535813190722836e-17; -1.659489096186948e-17; 5.953540565364387e-18; -9.87018899057946e-17; 2.5287294719479562e-17;;]), layer_3 = (weight = [-8.296906128571756e-5 6.126319237471039e-5 -5.420469063149007e-5 -6.03157161943401e-6 4.444310649948205e-5 -0.00010085432552780571 0.00011751893213902215 1.0014042682694348e-5 -4.763608342667053e-5 3.996229531427849e-5 -6.18591196082737e-6 -9.543172894094037e-5 0.00017989101649679848 2.559105484640176e-5 8.896821998011334e-5 -5.0849157223027505e-5 -5.3100214817574885e-5 
4.0171788323885146e-5 0.00010814731683304551 4.828366251846789e-5 8.153309351718097e-5 -0.0002420665474613382 -2.961480061397384e-5 -6.846779399832696e-5 -6.516186806887028e-5 -0.00019479766308844074 -2.967877810926807e-5 -0.00012303752703639887 0.00016065011038396564 -0.0002049485407308861 6.517409962788393e-5 -0.0001166557118534294; 4.411182590305235e-5 0.00020232664019459941 4.0311880641472425e-5 0.00017390661878635854 -0.00018298498116309169 -2.8043454674104616e-5 0.00022360309731340262 6.015675671988749e-5 -0.00010068655182948634 -0.0001409474521071313 8.873659228641068e-6 9.522694517982873e-5 -9.564258152493922e-5 -6.8758333526319e-7 -0.0002483283192279411 0.0001005205972458609 7.148249779021538e-5 -0.0001536412769357334 -0.00011880695319373229 -7.133797271052718e-5 1.5842185805768797e-5 -0.00010930214420767849 -5.648438900351842e-5 5.984744120979338e-5 9.651794016718894e-5 -0.00014641190000219616 -0.0003333679655926815 -1.2667568373065466e-6 1.6207711726391458e-5 9.709923236012628e-6 0.00011689351731805084 -7.2690398608021e-5; -9.451749589394472e-6 2.0038560636433847e-5 0.00011599296378667536 -0.00010673930553488702 5.675066170047777e-5 3.1791489549190114e-5 8.190193004670816e-5 -9.972754521763141e-5 5.634520896242698e-5 -3.494958907717092e-5 -4.180501453115194e-5 2.0253816204473424e-5 5.916846605364442e-5 -4.104369834417101e-5 0.0001257738953642341 -0.00014814857747873783 4.230342931228065e-5 9.630208335116294e-5 -1.2040528942070016e-5 -7.472397881096424e-5 -0.0001228670112835619 6.469009206162576e-5 0.00011598175153599836 7.65971530414912e-5 4.3869062595744204e-5 -0.00023943515732420808 1.3808055358504097e-5 7.543414214856097e-5 0.00015396664496043465 -3.9891830580329504e-5 2.9167510023069083e-5 0.000126214658324586; -0.00018614220227701234 -4.251492630228378e-5 -0.0001290337348322299 7.311525057648425e-5 0.00011710068445299893 -1.4750527792239713e-5 -5.6537082694741135e-5 -5.158143887765561e-5 -2.385696445653552e-5 -1.0873482845817634e-5 
-7.393057223762352e-5 -8.44841120760693e-7 0.0001443939874448903 0.00017491109990860356 8.902670758889198e-5 5.772300226144537e-5 8.56765874928888e-5 -2.616076993398977e-6 0.0001504573485149998 -9.027954386874075e-6 -0.0001874155676190691 -2.6027919214720332e-5 -5.7994587972097606e-5 0.0001457276995794024 -0.00010921525592802986 -0.00012703465092581552 1.1131723273682044e-5 5.4957872795601616e-5 0.00014796881093989258 -7.540890130567221e-5 4.40754048910039e-5 3.785912865769257e-5; 0.00019415519265985983 -0.00017559980558630846 7.807833382925425e-5 0.00013261429914805884 7.160264424115437e-5 9.862920335094655e-5 0.00016940410198244162 -4.385205109187378e-5 6.18147059164189e-5 0.0001340932685003806 7.479776461209395e-5 -1.1101322413509953e-5 -1.629879450534223e-5 7.518689010125733e-5 0.00016757907353408646 -0.00016841872645941263 3.537318473049432e-5 -0.00013181040055982796 -0.00010863782888684327 -7.675233435051503e-5 1.718912965099328e-5 0.00018902402727931768 0.0001661562745745476 5.3988078364143264e-5 -5.777742997211149e-5 4.0366980031337705e-5 6.880173525372911e-5 -0.0002197036047412966 -2.6361927736928398e-5 0.00014945768991552334 -4.455925234413948e-5 0.00010108222697695704; -7.933840476363373e-5 0.00010112353396216936 1.4265849525329807e-5 0.00013615506525930835 -1.4983298175856805e-5 -0.0002209511608878468 -8.659327665263383e-5 0.0001767488842535482 -7.695122855401672e-5 1.2354191706627344e-5 0.00013915086804736963 0.00012845171789486773 -2.475761509045277e-6 2.316574267134126e-5 3.051515107764362e-5 -0.00010330167657847855 3.2248919008639466e-5 0.00010993498784338074 7.576575938689734e-5 -0.00010091856762918679 -0.00014358780987714357 -1.195429337420997e-5 3.071270242181548e-5 9.64867626255095e-5 -9.615341202042425e-5 1.9900098262233307e-5 4.961011421364928e-5 5.122291389234602e-5 6.211538623866555e-5 -0.00011905586376705233 0.0002502862893681503 -1.3845406617101291e-5; 3.0959018924438215e-5 4.244343407839786e-5 -0.0001267983122561917 
-0.00012123879669894424 -1.2930678811299298e-5 4.9317424138227925e-5 4.9962284985616445e-5 3.820588726006431e-5 2.0463937555123283e-5 3.309703360349852e-5 -0.0001235224142038148 0.00010971149212486154 -6.173605170993692e-5 2.489460102146277e-5 -0.0001411076616096219 -0.00010128812092085166 0.0001370087459617418 -1.7973320344081162e-5 -0.0001354689545296726 -4.239082291684126e-5 -7.421689470724831e-5 5.73576862001542e-6 -5.4117964016823064e-5 -9.927500896816154e-5 -0.00012152108202649995 7.872417090033169e-5 -9.811502487360016e-5 -1.054508932661692e-6 0.00012664465365736755 0.00011513034886951538 -1.921981375566765e-5 0.00016230391819766716; -0.00012202784305346134 6.117101261251783e-5 -0.0001292486270810555 -6.962825896126024e-5 0.00011433489239634215 -1.6616006496137095e-5 0.00019748062154603357 8.237197439511506e-5 -0.00018598296478531125 -0.00015853019463038805 -7.77647440824733e-5 8.050959118035816e-5 3.17731788984889e-5 2.9741171283872694e-5 7.18281441098006e-5 -0.0001848790419441723 -9.973752957834543e-6 -1.6995261145930963e-6 8.579851387393655e-5 -7.902509273635955e-5 -2.6680347536698904e-5 0.00014714788962946472 0.00010137661561229465 -6.78304425943716e-5 -7.017685161345459e-5 0.00011276073169055595 -2.097456340352848e-5 -5.7865556618439304e-5 6.860938415256526e-5 0.000131642707538318 6.94172845241585e-6 -1.71632640101089e-5; 5.3995080067180276e-5 -5.224906419555445e-5 -0.00026026214038522903 6.164203512214513e-5 8.924293649813597e-5 -1.7353942551768925e-5 0.000262744665066359 -9.169044062530184e-5 -3.918530057801455e-5 -4.325690463103845e-5 -3.604081359610663e-5 7.711435171274853e-5 -0.00017112467469217153 3.936605060205073e-5 -0.00013172898763825683 -0.0001536855576190886 0.0001112050507586652 1.2392201257694566e-5 8.533902871618246e-5 7.051814316271308e-5 8.091682536957154e-5 7.371956451295381e-5 0.00012445461322984147 8.339076009776877e-5 -0.00011529333425953746 6.0056607595993825e-5 2.7434857089766413e-5 8.946308514766433e-5 -0.00011718097232317139 
4.031205605913394e-5 1.8360276513369047e-6 6.51364156297637e-5; -9.620677204776909e-5 -0.00012052602736488911 4.91130262732083e-5 0.00010185882953251336 -7.733346186555471e-5 -9.612964689708341e-5 -1.2141979402569765e-5 -0.00017034330134024398 0.0002190180452659218 -4.1377712966413494e-5 -2.947671104024665e-5 6.866137069905739e-5 5.653493156076933e-6 5.1496313483454415e-5 0.00011850854158424294 -4.882620172551914e-5 -0.0001575668652888842 -0.0001502380260147666 -0.00015709589255247849 -0.00019265646600374498 8.66368596932817e-6 3.3008461707732444e-5 6.27452168033879e-5 4.411941483932192e-6 -8.882398882391918e-5 5.986538550373619e-5 9.565330352657914e-5 -0.000123231053782856 0.00013966592765668907 -1.0551214357741688e-6 1.3015309798609233e-5 -6.830351825315037e-5; 0.00014043478155144578 0.00014586552721083574 -1.05812412039162e-5 3.4149448092638875e-5 -9.09426423641079e-5 -0.00010411038669190898 -0.0001718479804731115 -8.924735757328482e-6 1.1884442992585953e-5 -8.385959221414933e-5 -0.00013332457887373463 0.00011912770720299651 -1.1355494953030673e-5 8.160286530254821e-5 0.00013524539395802013 9.634335152440161e-5 -5.7598707589940464e-5 9.013387407311863e-6 0.00010731753287473262 5.336907930540211e-5 -6.337364606224995e-5 -5.840331572472761e-5 -6.968977389989122e-5 -9.548989759424411e-5 2.8841955336238496e-5 -0.00014465228445261426 -2.922852103439105e-5 8.575352991100908e-5 0.00010778727597425062 -0.00012160791515271822 0.0001716381041296662 -0.0002254927865033195; -6.989123630093344e-5 0.00017408998701968435 0.00013634898369447023 -0.00010844312327433642 6.095674968966996e-6 -0.00014905463012153974 -0.00011101846302379845 5.3436445348345266e-5 8.212927514934707e-5 -5.377651813658083e-5 3.658620790476235e-5 0.0001487792587863813 -0.00010333943190970531 0.00021884841863312955 4.7009955785617674e-5 1.8589786418642699e-6 5.655934092602776e-5 0.00011660612381776626 0.00011578069553021309 0.00017899436080144574 7.059337531926915e-5 -7.932479911003395e-5 
-3.687503976150594e-5 -0.00016800489083150546 -7.4600147786289675e-6 0.0001516052843794771 -1.3510853016205022e-5 7.169742912766278e-5 -4.357361375921605e-5 2.550203772218306e-5 -9.510427025134984e-5 7.240423747818757e-5; -0.00011767377968133896 3.172334159539653e-5 0.0002325499435355986 -6.897323641120656e-6 8.97648196369618e-5 9.792198357944418e-6 -7.575653675662908e-5 4.075649445489909e-6 -4.195526802863772e-5 3.686825489790411e-5 3.400017790777297e-5 1.9148020288460826e-5 1.564623467497355e-5 7.860178353080159e-5 8.984813662760177e-5 7.83700029937758e-6 0.00014857209221740702 2.6984705033150552e-5 0.00012535427863985684 -0.00014617727087529156 5.0684738158474494e-5 2.1161072214504158e-5 3.4200234000358686e-5 7.327595718020789e-5 5.7838823482642784e-5 -5.7825270272537765e-5 -3.750724230845402e-5 -6.325145026854118e-5 -1.9601132447811102e-5 3.469247799881035e-5 3.692176592814818e-5 4.111798894523842e-5; -4.848572143920175e-5 7.77558443273475e-6 -4.479178190616847e-6 -4.835927257176556e-5 6.733653344468364e-5 1.55512675487868e-5 -1.2119736769325622e-5 0.00011425219190328541 0.00012973552429371547 0.00010530731151134877 -1.725410266176653e-5 0.000131604208487765 4.253933138583355e-5 1.3023256084621411e-5 0.0003730557705208022 -6.301368785302848e-5 -2.6277064992842907e-5 4.789993451192115e-6 -3.264323322905337e-5 2.0555714851577362e-5 0.0001258664756246651 0.00010876633814091952 1.7345753346086387e-5 1.5191320468654434e-5 4.1016702649855944e-5 0.0002024973850862417 3.977298319708634e-5 -5.6167895766941714e-5 3.335744758856529e-5 6.849918781568469e-5 -3.542742206582235e-5 -4.04964871481236e-5; 9.763679369912827e-5 -5.752700783798548e-5 5.1952511031205274e-5 2.8087477377063224e-5 -0.00014810479663449131 4.5979139004861354e-5 -8.94952409262507e-5 0.0001374474649955819 -8.155631988213219e-5 0.0001325544417077045 0.00010143214419347179 7.596074826412977e-5 5.382458582150535e-5 8.804605359704358e-6 -6.424001369762025e-5 1.2451678204319635e-5 0.0001816293788315062 
-0.00011887020266751561 -7.637591809639393e-5 -0.00014632038258152806 4.7916995727939154e-5 6.45284653464633e-5 0.00013130396652830025 -5.2703804639370893e-5 -0.00010863694663296498 -1.0925366977083271e-5 9.174490512211476e-5 5.130941824151835e-5 -0.00022384905790115528 0.00012229186356396628 8.130809689733818e-5 2.4823097165587492e-5; -0.00022460085365827226 -8.592825379827764e-5 0.00012519953979272738 0.00010685370743367943 5.51601931822961e-5 6.46930036912694e-5 -0.00017428778137918292 0.00010671192812360912 5.777535788575294e-5 0.00016063694583345295 -0.00011905424817604729 1.713897817744995e-5 7.231755971070669e-5 3.2613332036186416e-5 4.552739296109399e-6 -1.2987779160260997e-5 -1.541922411496962e-5 6.435617050947719e-5 -0.00010344674429296095 -4.344448298588304e-5 -6.226579917500928e-5 -0.00032615558503102423 0.00011251966297486897 9.932791539818034e-6 -5.438374701561198e-5 -0.00013958657128177493 7.43125594279082e-7 0.00015180482095666562 -1.3336272973624408e-5 0.00011169535518484277 1.2270254769020988e-5 -6.303941536834369e-5; -0.0002477352697509688 0.0001000887262294428 -0.0001319413001523886 -1.1459414692074283e-5 -1.3051674342426726e-5 -0.00010246213942213147 -0.0002276898610047622 -7.770315433066927e-5 0.000245157598600196 0.0001419853636421577 0.00011474091142969035 5.5830593099524116e-5 2.0404493612326907e-6 -0.00012983416827540972 -5.096608738444396e-5 -0.00010011284184787636 -9.638975616456979e-5 5.833196546387297e-5 -0.00014700094625495479 -0.00013524176183768197 -6.494417330161047e-5 9.146673890933108e-6 -0.0001006582621825542 8.326790548249497e-5 -3.207783602323571e-5 0.00020708513174991477 3.5747196578600317e-5 0.00016056721351430034 -0.0001086694698655659 -0.0001015767790717691 0.00010123150268425993 1.7604851778344933e-6; -9.068391266472748e-5 0.00014179196257183569 -0.00024294250072406427 -7.732523265720979e-5 8.907443368381203e-5 -2.910458572576974e-5 -2.5391460828439015e-5 -3.504382805393697e-5 -1.5407701018454264e-5 2.0861574351250536e-5 
6.25962125676895e-5 1.5650972026499661e-6 -1.4038408541716887e-5 4.962925920418703e-5 0.00012918357203753973 7.149718255542471e-5 -7.906811722102639e-6 0.00010819530832387816 7.363881702340793e-5 5.06007377888759e-5 -2.447917943490226e-5 -2.7930104285666388e-5 -7.071901238949661e-5 -0.00016543559506961768 4.2123381329395506e-5 -0.00011032059449197412 -7.761589988794e-5 0.00019580803161526282 1.6855068291000344e-5 9.590496640988496e-6 7.129639082367788e-6 3.305775278287317e-5; 3.696602080114253e-5 5.6543640042458395e-5 -2.2078361749081485e-6 -1.0656434184544806e-5 1.4839699277967292e-7 4.354977654971992e-5 -0.00019861779428223251 6.736541948578849e-6 -8.866166744750145e-5 -1.935331083271809e-5 9.363611179966356e-5 -1.8943741721647404e-5 1.393056636831156e-5 -0.0001109723219056253 7.488131969099707e-5 2.9490176260390924e-5 2.154043404831439e-5 -0.00012140359312664872 -0.00014297065926851513 -9.432295909986058e-5 4.2942288658619126e-5 2.2478188569460364e-5 0.00010015484595302337 8.02431111000835e-5 0.00010168967283194761 -7.133705939535209e-5 -1.7529173689159495e-5 -2.44988935233196e-5 0.0001452287558126621 0.0001489755538418315 -0.00015612869019290398 1.2820504437934922e-5; 1.439782694087147e-5 3.245691108185242e-5 -0.00029387969049258616 -4.632132992247765e-5 -8.863535991969255e-5 0.00016047780003299375 0.00011034098055114453 -3.8152223045301816e-5 -6.885069966648429e-5 -4.075989351241591e-5 -9.991827470349713e-5 9.801326008181376e-5 5.700912073388195e-5 -1.523687377640755e-6 -1.851682360558043e-5 -2.8485136557548465e-5 0.00010737519836419662 -3.9480281760745806e-5 -5.618764122464295e-5 6.972999029693457e-5 -8.047471861919516e-5 -1.0999247809934613e-5 -6.43716056072778e-5 0.00010291692081972288 -7.353689658322008e-5 -3.119681501976884e-5 6.314275300263904e-5 5.570585484400855e-5 0.00018013580769201084 2.7928547875660266e-5 -0.0001338872615304412 -6.158033887743619e-6; -0.0001964367911097871 -0.00011928108137748563 -9.80817421334327e-5 -3.89681226336003e-5 
-2.9554143064384072e-5 5.585352851245162e-5 -4.6944329322504125e-5 -3.804068178237299e-5 -0.0001599708993564901 -1.414258890531146e-5 -0.0001271490691108203 1.5391426178899038e-5 -9.320762905939628e-5 -2.227414365574142e-5 0.00024151654365755747 0.00016515077009149389 0.00012853774392329823 4.6683020674895064e-5 -3.66838721203745e-5 -0.00017783862854070832 -0.00013570501318847374 -5.7197415979540664e-5 9.172919094612972e-5 4.366016587642145e-5 -4.968163552643215e-5 4.0553673945172994e-5 -0.0002176061461325265 8.167263881820238e-5 2.2621636209582723e-5 -2.493276765403226e-5 -4.96925754143447e-6 0.0001899232811881616; 5.737882025860736e-5 -7.547979237260239e-6 -0.00019681867701145316 -6.787732849041046e-5 -7.413463021073487e-5 -0.00016236457298653413 5.016069559353099e-5 0.0001260367208146118 -6.342754926443123e-5 -0.00014600477146593847 -5.182877060868053e-5 3.1241321323634565e-5 8.527131291195414e-5 0.00012836962565636033 -1.236646306786219e-5 -7.445827066093453e-6 3.0179946463631508e-5 6.354739167584098e-5 0.0001982793838236825 0.00011174417892327534 9.75608491557169e-5 0.00013481669162722308 7.927947137816686e-6 -0.00011987012834431601 6.694105110220413e-5 0.00011235024436466398 -6.704970286371618e-5 -9.249339922242272e-5 0.00021251019649064148 -1.2646737502633982e-5 -1.4686067459105926e-5 -0.0001278575290940144; -1.8854747621448303e-5 2.444152755871682e-5 -6.578661327564986e-5 1.3017853917893907e-5 0.00020979044870856603 -0.0001817536560780205 -3.3249232954933216e-5 0.00015144778967752503 7.140843604376035e-5 9.025930717272133e-5 -2.841798618518044e-5 6.349792761412914e-5 7.429548305861783e-6 3.8638774414190935e-5 5.6631850158203554e-5 -0.00019341879146721476 -3.272711023653984e-5 -7.441012189573203e-5 2.4424932918417138e-5 4.205328309709595e-5 7.735630583868529e-5 -0.00013161465380888982 -7.858332742088064e-5 -7.521656720981543e-5 -6.103801047045979e-5 -9.230307682499831e-5 -2.6069925523715875e-5 -1.3644340695991894e-5 -3.5005027028625324e-5 
2.7375139099069944e-5 5.558509815403532e-5 -0.00024249371403910753; -7.72926367169732e-6 -0.00010100092666217305 -3.200435574508868e-5 0.00016212342117686112 -1.469025079328974e-6 0.00011580152451532675 1.2011263351079739e-5 -9.765103939682087e-5 0.00017910108246613527 2.218867004447303e-5 1.5088797625041275e-6 -1.4707094379641853e-5 -6.923025267685418e-6 -6.815752641350807e-6 3.934623484272675e-5 0.00014096273909561987 -4.5698948525759e-5 -0.00024403697651106672 -3.646406665316693e-5 1.470236904313722e-5 8.00873919786809e-5 -3.704243252397357e-5 -0.00018389696677442434 -0.00010777593459462579 -0.00018597445640008213 4.1650406927859216e-5 -2.7846822632640854e-5 -7.197691346046238e-5 -0.0001014678757939838 -0.0002240122896921247 -7.182061133897512e-5 6.235909197019599e-5; 1.4895572736195015e-5 4.2314435109627475e-5 -0.00013381614423832428 -3.970990526267533e-5 6.81746254720603e-5 -5.290241331221464e-5 0.00016359793646716737 -0.0001952116932208772 -9.407964693487262e-6 -4.954898629348089e-5 3.8223371507963807e-5 -4.316621796005604e-5 -1.2965006657568257e-5 -6.235426424081336e-5 -0.00010031624567472664 4.708443286818482e-5 -3.6111664098089644e-5 2.457980865602113e-5 5.6637128569349205e-5 8.433548775956678e-5 0.000255039697936649 -6.214478942110496e-5 -7.627681565459648e-6 3.278620117799094e-5 -7.707124184361608e-5 -0.00010037139015748452 -0.00010209910267275318 -7.461979710041151e-5 -0.00013067664128738702 -1.9355421112955718e-5 -8.111206859591622e-5 8.243740142079726e-5; 8.961083276951833e-5 0.00019724498325087277 0.00015462476640533155 2.9639392418924573e-5 0.00015939952627816196 3.715538751277335e-5 6.371355634486654e-6 -8.66652329888664e-5 -0.00010689577847658784 -0.0001918751071527137 0.00015731318908803369 1.7659127523286395e-6 0.0001480671931902298 6.388661220779391e-7 8.312867485512933e-5 0.0001392533746228022 1.8688766981372495e-5 0.00012145895798751075 4.441337351158713e-5 0.00016538173131705217 0.00014840113054086609 -0.00010969654194484501 
-0.00011846118231476375 -4.157608363600496e-5 1.4804909010143044e-5 0.00012930021385257085 4.377544978351289e-6 -1.2040004908330381e-5 -0.00012699433081722936 -8.224425927908185e-5 -3.626266279317516e-5 2.5373636712924805e-5; 0.00012934441482243203 -8.392627197721969e-5 1.30958479395797e-5 -2.457311158181491e-5 7.811297844924213e-5 -3.971481341835865e-6 0.0002181098608026835 -2.0619565398048665e-5 -6.270670557100165e-5 3.387240584302876e-5 6.887231470642356e-6 -9.063122811169365e-6 9.530119867591109e-5 0.0001563909246080183 0.00011057658419773162 -6.7406689730376046e-6 0.00014669646778671042 1.1570369756698456e-5 -4.1979954325925046e-5 -0.00011300927864325985 2.7442110954460788e-5 -0.00028362937875030383 0.0002665948923869634 -3.419644409845405e-5 -1.3989555477422856e-5 6.845790638650016e-5 -8.531976086347445e-6 -1.88099692557858e-5 -0.00010572303282501014 -0.0002805432085686718 -3.248217755285787e-5 -2.323243202083411e-6; 3.990914360752929e-6 -1.5515956197214595e-5 -0.00015442168501513098 -8.039665977312468e-5 0.0001120882386613613 -1.3861225174787814e-5 -4.5126306081238074e-5 1.3195679225735275e-5 4.1471133722642026e-5 -0.00010241449331706634 -0.00016438497391836854 -8.499706587363248e-5 4.3365713049843274e-5 -6.317307839244722e-5 -5.9993899644634295e-5 0.0001774963836806812 -1.3988040568529188e-5 -8.499421369824795e-5 -8.641162117318301e-5 -0.00010764044987345356 0.00025578171493644225 6.906930661452691e-5 -0.00010650906029232081 7.336279827095327e-5 0.00011119912391686678 6.040867969668231e-5 8.843461705064388e-5 9.156563623500959e-5 -3.6602147024406545e-5 -6.775608885529328e-6 -1.161098142189855e-7 -6.365857394021243e-5; 3.1327858426505467e-6 2.9149508883433373e-5 -0.00014291430674229534 7.42117572020014e-5 8.710451398773193e-5 -8.419317422404276e-5 -4.4500493494254526e-5 0.00011688753325987676 5.688981942629582e-5 0.00015478948804346912 -0.00014288411151820195 -0.00014105154517775193 5.3108973553478625e-5 -0.0001069150367026897 -5.939834412739535e-5 
-0.0003170857290012521 -0.00012935745147724798 -3.5403389155100004e-5 -6.090111677499645e-5 -1.1354283777756305e-5 0.00010307603403027166 2.2340718679778045e-5 1.948178393039499e-5 -4.493127384076748e-5 1.3634127718700289e-5 -6.360618079917421e-6 -0.0001264692455503788 3.919052138489831e-5 -6.258630859402807e-5 -8.714673675705522e-6 0.00013613986990214502 -0.00010783590372621917; 6.229899295619251e-5 -2.2271523308458836e-5 -7.587049626266352e-5 3.659958670511037e-5 -8.671722466345764e-5 -2.5246156697844294e-5 0.00011318490865420747 -0.0001396639462163288 4.668103352202224e-5 -0.00012114820569507667 -5.44136972584372e-5 1.624463871309196e-5 -0.00014392188026789366 -1.6995853771672676e-5 8.191945667391757e-5 -8.524679000941871e-5 0.0002493620806072298 -2.9209096142890627e-5 -8.638948642867844e-5 0.0001866761550937232 7.906918759005557e-5 -9.467528875205553e-5 -3.294961865861782e-5 0.00016831924872718867 -3.473785622730682e-5 2.5383314236082346e-5 -0.00011548665210824227 4.25292047603451e-5 -6.012276989488977e-7 0.00014725136420935505 -8.54982398286046e-5 -8.325349776526798e-5; -0.0001686441146963722 -2.8739502399448324e-5 7.648992590049341e-5 -3.331304566445017e-5 0.00011391036508512341 0.00010974167792965982 2.679937727890586e-5 -4.9282717613730334e-5 9.769566071127285e-5 3.0482630732242196e-5 -0.00010709777876826623 0.00014612756325305264 7.749156332948954e-5 -0.00022425299272498672 -7.130236650909635e-5 -0.00011579033622610135 -0.00014264570660515167 1.1387164645659944e-5 4.612545211364298e-5 -8.829329148918497e-6 -0.0003571750758643797 9.964799659001295e-5 -0.00025402091438508895 -4.433087912247853e-5 8.975793775504101e-8 1.661170875677503e-5 -7.244345494171532e-5 2.6936603658501865e-5 0.00012103629245330311 -7.945075509696623e-5 1.532823892861595e-5 -0.00013790035615302004; -8.519429068759727e-5 -1.6421608790014588e-6 0.00012381560988511305 -9.066635651028339e-5 0.00010727133373516992 9.613943872760417e-5 7.591674399814929e-5 0.00014215854641302865 
2.1943002838142264e-5 -3.151301391984889e-5 -0.0001045611145228949 3.055693266001905e-5 -0.00011274806568611986 -0.00013097685699098383 -6.373856124913871e-5 -7.44168876109515e-5 1.6399217901252092e-5 -0.0001553110789528917 -0.00020659675759010713 -4.359761917943688e-5 -1.5500432010541402e-5 -1.550443924419668e-5 0.00014970471854108765 1.984635561608157e-5 0.00011537891875273916 -2.097677435031796e-6 0.00014360896490767013 2.3048308111227846e-5 2.016613031524611e-5 2.3956016565406764e-5 -1.4598720764426292e-5 5.7583800998017243e-5], bias = [-1.855032071454212e-9; -2.010435484194216e-9; 1.6920454595525974e-9; 5.567713748361622e-10; 3.656098358701478e-9; 1.7293434308296313e-9; -1.2320998033769415e-9; 1.0154535261036615e-10; 1.0462614006701523e-9; -2.277863449694313e-9; -5.718357087341239e-10; 2.733038420480947e-9; 2.698430812472885e-9; 4.763475369568192e-9; 1.3369308154444078e-9; -6.277384227710118e-10; -2.0970232471753015e-9; 3.2702965975245947e-10; -1.7307369195663647e-10; -5.31930539031226e-10; -2.088152765126035e-9; 1.0696225427954904e-9; -1.6173076684972968e-9; -2.7551152169741296e-9; -1.7958647694381388e-9; 3.4202501641107216e-9; 8.077087457983028e-10; -5.58661578801201e-10; -2.6684926630507836e-9; -2.9544565714798754e-10; -3.3224429877929525e-9; 7.16398601170707e-11;;]), layer_4 = (weight = [-0.0006818293675188689 -0.0005665056523122616 -0.000523549759299519 -0.0006701881553611344 -0.0007683228417130926 -0.0007638732316637779 -0.0006771422014475052 -0.0007362503880978459 -0.0006981340187793433 -0.0005856101856142981 -0.0006769345423462713 -0.0005923118826428556 -0.0007855412237443182 -0.0007694598226029526 -0.0007830686566470464 -0.0008082676395373985 -0.0006877992261216983 -0.0007451334931818611 -0.0005645243565327377 -0.000610978356640351 -0.0007583023248083534 -0.000764035495688506 -0.0006996982551365961 -0.0006683440482099499 -0.0006959085318854217 -0.0009126472758584874 -0.0006613012424834696 -0.0006583648942900125 -0.0007867248088698685 
-0.0006786878809112434 -0.0006209293769790405 -0.0006114777123154171; 0.00035691116275523584 0.000176114319718107 5.620326285614106e-5 0.0003066645238602707 0.0003090582402435736 0.0002467384030248133 0.00013168029479327105 0.0002211884284273378 0.0002599197275780103 0.0002613233699643823 0.00015096963189209278 0.0003721502464233762 0.00019210638647400627 0.0002074606245565633 0.00023152873299905395 0.00014348923373347377 0.000257840908756893 0.00029461748963386953 0.00034683201040023286 0.00037380347392693315 0.0001981397407681662 0.0001974424923413249 0.00012198248213650891 0.0002131421184547663 0.00019012250042503266 0.0003725319996173898 0.0002582435470742101 0.00031194841158330356 0.0001402354569670187 0.0001854925538357581 0.0002349768352252696 0.00014789052923758447], bias = [-0.000679384930616981; 0.00022659589429048274;;]))

Visualizing the Results

Let us now plot the loss over the course of training.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # Continuous trace of the loss history collected by `callback` ...
    lines!(ax, losses; alpha=0.75, linewidth=4)
    # ... overlaid with one marker per recorded iteration.
    scatter!(ax, 1:length(losses), losses;
        strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results by comparing the data against the untrained and trained neural-network waveforms.

julia
# Rebuild the ODE problem with the optimized parameters `res.u` and integrate it
# with fixed-step RK4 (step `dt`, adaptive stepping off), saving at `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4();
    u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))
# Keep only the first element of `compute_waveform`'s result
# (presumably the waveform itself — confirm against its definition).
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Every series is drawn as a line plus semi-transparent circle markers;
    # each (line, scatter) pair becomes a single legend entry below.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5)

    # Legend placed in the lower-left corner of the axis.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
# Print the Julia version, platform, and environment details used to build this page.
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report CUDA toolchain/device details only if LuxCUDA was loaded (its name is bound
# in this module) and CUDA reports itself as functional.
if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

# Same check for AMD GPUs. Note: LuxAMDGPU.functional() may emit a warning when the
# package is loaded but no usable AMDGPU is available.
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 4 default, 0 interactive, 2 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 4
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.2

CUDA libraries: 
- CUBLAS: 12.5.2
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.2
- CUSPARSE: 12.4.1
- CUPTI: 23.0.0
- NVML: 12.0.0+555.42.2

Julia packages: 
- CUDA: 5.4.2
- CUDA_Driver_jll: 0.9.0+0
- CUDA_Runtime_jll: 0.14.0+1

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.623 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/DaGeB/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.