Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

# Turn accidental scalar indexing of GPU arrays into an error rather than a slow fallback.
CUDA.allowscalar(false)

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model, which, from a scientific machine learning perspective, are not especially relevant.

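We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to get $r_1$ and $r_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, Sec. 4.3). With $M = m_1 + m_2$, this gives $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$.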

julia
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of two bodies with
masses `m₁` and `m₂`, measured from their centre of mass.

Returns `(r₁, r₂)` with `r₁ = m₂ / M .* path` and `r₂ = -m₁ / M .* path`, where
`M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    # Each body is a mass-weighted fraction of the separation, on opposite sides.
    w₁ = m₂ / total
    w₂ = -m₁ / total
    return w₁ .* path, w₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Map an ODE solution to Cartesian orbit coordinates, returning a matrix with two rows
(`x` and `y`) and one column per column of `soln`.

`soln` must have 2 or 4 rows; rows 1 and 2 are `χ` and `ϕ`.
- With 2 rows, the constants `(p, M, e)` are taken from `model_params`, which must have
  length 3.
- With 4 rows, rows 3 and 4 give `p` and `e` per sample.

The radius is `r = p / (1 + e cos χ)`. Throws an `AssertionError` on an unsupported row
count, or when `model_params` is missing or has the wrong length for a 2-row `soln`.
"""
@views function soln2orbit(soln, model_params=nothing)
    nvars = size(soln, 1)
    # ASCII `in` rather than `∈`, so the check survives copy/paste of this code.
    @assert nvars in (2, 4) "size(soln,1) must be either 2 or 4"

    # Rows 1 and 2 hold χ and ϕ in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if nvars == 2
        # Check for `nothing` first, so a missing argument fails the assertion
        # instead of raising a MethodError from `length(nothing)`.
        @assert(!isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params  # M is not needed for the geometry
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function computes the first time derivative using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled signal `v` with spacing `dt`.

Uses second-order central differences at interior points and second-order one-sided
stencils at both endpoints. The result has the same length as `v`, which needs at least
3 samples.
"""
function d_dt(v::AbstractVector, dt)
    # forward one-sided stencil at the first sample
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # central stencil at every interior sample
    body = (v[3:end] .- v[1:(end - 2)]) / 2
    # backward one-sided stencil at the last sample
    tail = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(head, body, tail) / dt
end
d_dt (generic function with 1 method)

This function computes the second time derivative using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled signal `v` with spacing `dt`.

Uses the three-point central stencil at interior points and four-point second-order
one-sided stencils at both endpoints. The result has the same length as `v`, which needs
at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    first_pt = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    inner = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    last_pt = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(first_pt, inner, last_pt) ./ dt^2
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit, followed by helpers that turn it into the (2,2)-mode gravitational-wave strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free second moment tensor of a planar orbit, scaled by
`mass`.

`orbit` holds `x` in row 1 and `y` in row 2. `component` is an index pair:
- `(1, 1)` gives `x² - (x² + y²)/3`.
- `(2, 2)` gives `y² - (x² + y²)/3`.
- Any other pair gives the off-diagonal `x * y` term.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    xx = xs .^ 2
    yy = ys .^ 2
    trace_xy = xx .+ yy

    # Short-circuits exactly like `component[1] == k && component[2] == k`.
    ondiag(k) = component[1] == k && component[2] == k

    moment = if ondiag(1)
        xx .- trace_xy ./ 3
    elseif ondiag(2)
        yy .- trace_xy ./ 3
    else
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the requested
`component` of the trace-free moment tensor of `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole-strain components for `orbit`, sampled
with spacing `dt` and scaled by `mass`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    # Evaluated in the order (1,1), (2,2), (1,2).
    h11 = quad((1, 1))
    h22 = quad((2, 2))
    h12 = quad((1, 2))
    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the `(h₊, hₓ)` polarisations of the (2,2)-mode strain for a single unit-mass body
following `orbit`, sampled with spacing `dt`.

Both polarisations are scaled by `√(π/5)`, which is evaluated in the precision `T` of
`dt`. The cross polarisation is returned with a negative sign.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2)-mode normalisation is √(π/5). The square root had been lost (the line
    # read `(T(π) / 5)`); ASCII `sqrt` avoids losing it again.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return `(h11, h12, h22)` for a two-body system. Each entry is the sum of the
corresponding single-body quadrupole components of `orbit1` (mass `mass1`) and `orbit2`
(mass `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # Component-wise sum of the two (h11, h12, h22) tuples.
    return map(+, body1, body2)
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the `(h₊, hₓ)` polarisations of the (2,2)-mode strain for two bodies following
`orbit1` and `orbit2` with masses `mass1` and `mass2`, sampled with spacing `dt`.

The masses must sum to one (checked to within `1e-12`; otherwise an `AssertionError` is
thrown). Both polarisations are scaled by `√(π/5)` in the precision `T` of `dt`, and the
cross polarisation is returned with a negative sign.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0) < 1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2)-mode normalisation is √(π/5), the same as in the one-body case. The
    # square root had been lost (the line read `(T(π) / 5)`).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Convert the orbit solution `soln` (sampled with spacing `dt`) into the `(h₊, hₓ)`
polarisations of the (2,2)-mode strain. `soln` and `model_params` are interpreted as by
`soln2orbit`.

`mass_ratio` (written `q` below) must lie in `[0, 1]`; otherwise an `AssertionError` is
thrown.
- `q == 0` treats the orbit as a single test particle.
- Otherwise the relative orbit is split into two bodies with masses `m₁ = q / (1 + q)`
  and `m₂ = 1 / (1 + q)`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparison operators. The Unicode ≤/≥ had been lost, which left the
    # undefined names `mass_ratio1` / `mass_ratio0` here.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

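`RelativisticOrbitModel` defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants:

$$\dot\chi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2 \sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p-2)^2 - 4e^2}}, \qquad \dot\phi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p-2)^2 - 4e^2}}.$$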

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the orbital ODE for a point-like particle in a Schwarzschild
background.

`u = [χ, ϕ]`, and `p`, `M`, `e` are constants. Returns `[χ̇, ϕ̇]`; `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    cosχ = cos(χ)

    common = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    radical = sqrt((p - 2)^2 - 4 * e^2)

    dχ = common * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * radical)
    dϕ = common / (M * (p^(3 / 2)) * radical)
    return [dχ, dϕ]
end

mass_ratio = 0.0         # test particle (compute_waveform then uses the one-body strain)
u0 = Float64[π, 0.0]     # initial conditions (χ₀, ϕ₀)
datasize = 250           # number of samples in the waveform
tspan = (0.0f0, 6.0f4)   # timespan for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
# Spacing between saved samples; the finite-difference stencils in compute_waveform use it.
# NOTE(review): tspan is Float32, so dt_data is Float32 too — confirm this is intended.
dt_data = tsteps[2] - tsteps[1]
dt = 100.0               # fixed step size for the non-adaptive RK4 integrator
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model using OrdinaryDiffEq.jl and plot the results.

julia
# Integrate the relativistic model with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# compute_waveform returns (h₊, hₓ); keep h₊ as the reference ("true") waveform.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line, with a marker at each sample.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Group the line and markers under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input ($\chi$, the first component of the state) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs, and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
# Layers: elementwise cos of the input, then two 32-unit cos layers, then 2 linear
# outputs. All weights come from a truncated normal with std 1e-4, so the network output
# starts very close to zero.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# NOTE(review): `Xoshiro()` is unseeded, so the initial parameters differ between runs;
# pass a seed (e.g. `Xoshiro(0)`) if reproducibility matters.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-6.215363f-5; -4.7839574f-5; 0.00015146201; -0.000151337; 3.5000787f-5; 2.7854427f-5; -0.00014940996; 3.6730096f-5; 8.703845f-5; -8.741318f-5; -5.5544522f-5; 0.00012893965; 6.930595f-5; 0.00029030893; -5.0287457f-5; -8.838624f-5; -6.3532665f-5; 0.00012454047; -0.00017002253; -5.6763096f-5; 7.164081f-5; -3.1511647f-6; 6.220186f-5; -3.5101657f-5; 8.410985f-6; -4.933897f-5; -3.5775673f-5; -6.3360145f-5; -0.00010776173; 3.521968f-5; -9.2839204f-5; 0.00016278356;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[5.8262027f-5 0.00011413043 3.0613883f-5 -0.000107104555 3.89539f-5 -6.522347f-5 0.00014323382 1.0009757f-5 -0.00012474551 0.00014385053 -0.0001011316 -0.00013292492 6.915765f-5 0.00013105661 -9.601982f-5 6.562088f-5 0.00013870922 -0.00014629166 -0.00015983795 6.5794084f-5 0.00022424542 7.880576f-5 0.00014302289 -2.0618496f-5 2.011076f-5 9.982218f-6 -9.670668f-5 -4.462968f-5 -0.00024201465 -3.932332f-5 -6.36961f-5 -3.734165f-5; -0.00012752067 -6.697242f-5 9.075176f-5 -1.5289566f-5 -3.2197143f-8 -1.9411185f-5 -3.4235476f-5 0.00012610079 3.917339f-6 0.00013250837 -4.9450424f-5 4.1220326f-5 8.43856f-5 4.998967f-5 -3.8189104f-5 0.00013878485 2.5811529f-5 4.5835135f-5 5.428533f-5 6.0559716f-5 -0.0001511846 -3.752738f-5 -7.389504f-5 -4.439033f-5 4.6338926f-5 -3.4148532f-5 -0.00016501844 -4.073293f-5 7.5485244f-5 -6.694507f-5 -5.7413554f-6 0.00018908684; -2.528262f-5 -2.0005238f-5 1.2792177f-5 0.00010972871 -5.99499f-5 -4.505672f-6 -0.00012683896 1.4096045f-5 4.2667953f-6 0.000109058295 3.5502893f-5 7.323104f-5 -4.73568f-5 -7.590644f-5 0.00012224267 7.5575226f-5 -2.5715322f-5 2.3440547f-5 2.89151f-5 -0.000103131075 -0.00012895303 -9.9410405f-5 -5.6510402f-5 0.00014455043 9.681102f-5 -0.00014534101 -2.4380208f-7 -2.394964f-5 -2.2326476f-5 -7.4918375f-5 
6.446751f-5 0.00015493143; 9.3113544f-5 -4.193219f-5 1.7369353f-5 -7.909097f-5 -0.000117229174 -4.246502f-5 2.9499057f-5 -0.000103530496 -0.00017182478 -6.7082285f-5 -4.678808f-6 -0.00017663691 -6.456664f-5 -2.4189718f-5 -0.00017915512 2.4364967f-5 -9.259941f-5 0.00010917898 0.00017640334 -3.5617773f-5 -9.475157f-5 -6.727584f-5 -5.6127632f-5 9.662877f-5 -4.1317904f-5 5.3204563f-5 7.424898f-5 2.8990818f-5 0.0001739964 0.0001040401 9.147316f-6 3.2423017f-5; 2.7212653f-5 -2.2595828f-5 -9.013398f-5 -7.8230914f-5 -0.00010255985 -3.8572693f-5 -6.0909708f-5 0.00018090068 0.00020218414 -0.00011969984 -0.00010776683 -9.5107614f-5 -0.00030448404 -7.619056f-5 -2.1651775f-5 -0.00012207357 8.583296f-5 0.00012818605 -3.352512f-5 0.00019962154 2.5521886f-5 4.272255f-5 -8.858213f-5 -3.7567588f-5 5.0079823f-5 -5.27312f-5 3.1720174f-6 -8.32254f-6 -0.0001045863 0.00014829829 -8.574072f-5 9.9410456f-5; -1.5369747f-5 1.60433f-5 0.00032239006 -0.00011894804 0.0001264982 0.00013720684 0.0001373411 5.4504385f-6 2.781403f-5 -5.2700747f-5 9.0780115f-5 -8.7622975f-5 3.4636647f-5 -4.228504f-5 2.7535392f-5 0.000109743414 0.00016666576 5.100671f-5 -0.00011682792 8.8771776f-5 -0.00017250184 -0.00013902487 -1.5057209f-5 0.00013833519 -7.900253f-5 9.936032f-5 3.5998753f-5 -3.1559768f-5 -0.00013258502 2.2093584f-5 -0.00017211042 -0.000114189315; -1.5752941f-5 -0.00014244819 -0.00011999296 0.00015319104 -7.756481f-5 4.066601f-5 3.451434f-5 -0.00016883876 -0.00030031285 0.00016729138 7.234866f-5 0.0001923409 0.00015907659 -0.00026098933 -7.085347f-5 8.208732f-5 0.0001577419 -4.538808f-5 -5.6924724f-5 -8.56555f-5 8.37105f-5 -0.00021203511 -0.00015068539 -5.7609388f-5 -9.028021f-7 -0.00012966008 -0.00017489343 -9.468509f-6 0.0001816555 -0.00014238287 -0.00016046593 -9.012232f-5; 4.6962727f-5 -9.188754f-5 0.0001979095 0.00014192084 3.5561075f-5 8.804627f-5 0.00017009047 -4.700047f-6 0.0001646285 2.1994663f-5 0.0001402198 -7.141821f-5 -0.00014351055 1.3236163f-5 3.0881325f-5 -0.00022609539 -3.9602223f-6 
-9.032644f-5 -9.121021f-5 -7.1539806f-5 -5.8485035f-5 -0.00011256034 -7.559741f-5 9.903596f-5 -9.6577576f-5 0.00011871355 -6.0076443f-5 7.38433f-5 6.479226f-5 3.4850957f-5 -8.858951f-5 -3.597631f-5; -1.247356f-5 4.9365786f-5 8.334513f-5 -6.707356f-5 1.1696292f-5 0.00012564068 -9.347275f-5 -5.226121f-5 -2.559152f-5 5.2162533f-5 3.7843165f-5 0.00014268071 -6.478889f-5 1.8953837f-5 -0.000116283096 7.305014f-5 0.0002417459 4.247642f-5 0.00015789451 0.00020869724 8.353426f-6 9.733433f-5 1.0650671f-5 -0.000101519836 -8.331701f-5 0.00010319449 3.915515f-6 -2.7179274f-5 -3.4879188f-6 1.105256f-5 0.00017358527 -4.6637517f-5; -5.4332893f-5 -9.5465635f-5 -8.2952785f-5 3.7698836f-5 -2.4740692f-5 2.4854935f-5 -1.4494514f-5 8.511136f-5 5.5581797f-5 0.000119297925 -5.055708f-5 -0.00024865379 8.4500185f-5 -0.00017171804 -0.00020363899 -0.00010821279 -7.144557f-5 -0.0001040709 -3.5004537f-5 -0.00015017344 -0.0003293547 -9.37212f-5 -5.9422324f-5 -7.80203f-5 0.000146864 -2.6253172f-5 5.0286268f-5 2.5818417f-5 -8.3153434f-5 -4.295352f-5 0.00011429918 -8.340714f-5; 2.4064648f-5 -4.345962f-5 -2.6050533f-5 -8.675837f-5 -6.836668f-5 0.00010480646 -5.3243744f-5 -0.00013593976 5.4545824f-5 -3.643757f-5 -1.1548249f-5 0.000109669505 -9.805119f-5 -0.00014562788 -4.981238f-5 -0.00021741948 -9.723044f-5 -6.636689f-5 0.00011041657 1.18037215f-5 8.574018f-5 7.0452093f-6 0.00013185944 -0.00020989701 -7.7680634f-5 -5.3661606f-5 -8.569317f-5 3.8252463f-5 0.00015550398 -3.342001f-5 -0.00022293077 -0.00014476763; -0.00010452346 0.000108830405 1.2148175f-5 -1.6982749f-5 0.00014278125 -0.00013418865 -4.112532f-5 4.0614945f-5 -2.931141f-5 0.00018660874 2.7220498f-5 -7.824039f-5 -0.000106423155 7.7449746f-5 0.00012344157 -5.7723937f-6 6.823634f-5 0.0001306498 1.1832669f-5 -8.917672f-5 -4.528865f-5 5.4061336f-5 -0.000117154596 8.2644314f-5 7.918797f-5 -0.00015576088 -0.00019170265 1.7267737f-5 -3.5846704f-5 -0.000102188584 -1.6543809f-5 -0.00024388162; 2.2468004f-5 3.3346856f-5 -9.364029f-5 4.6072f-5 
-4.805988f-5 6.988886f-5 6.555611f-5 -2.7067429f-6 0.00023465157 -0.0001592426 -0.00015601261 8.5428546f-5 3.088076f-5 -9.9020544f-5 4.052149f-5 -0.00010118321 5.6325254f-5 -2.3966584f-5 -0.00010306356 -4.692069f-5 -7.5198375f-5 -7.140273f-5 -2.6763799f-5 0.00010827641 -7.6403216f-5 4.2491592f-5 1.4352136f-5 -5.053314f-5 -3.92531f-5 -0.00015071167 -8.4761276f-5 -2.4300361f-5; 4.949985f-5 0.00013659641 -5.8884616f-5 0.00016168682 9.403377f-5 -0.0002023165 -9.198696f-5 -0.0001933505 1.3519738f-5 0.000105000574 -2.9817847f-5 0.0001706025 7.6517887f-7 8.912275f-5 7.902211f-5 1.5402087f-5 0.00015565146 9.745499f-5 0.00013882133 0.00011190831 9.046141f-5 -0.00017703729 5.640159f-5 -3.4943972f-5 2.8317103f-5 -3.4272084f-6 8.9997906f-5 -4.1235453f-5 -4.5611956f-5 0.000116993666 0.00014777895 8.907446f-5; -1.0871559f-5 -2.8972365f-6 -6.494627f-5 0.0001031088 -0.00020170052 3.790485f-5 0.00012688439 -7.029904f-5 -0.0001102703 0.00010602056 2.6997981f-5 1.8321794f-5 2.4578047f-5 -0.00013638416 -4.610365f-5 0.00015636029 9.770625f-5 -0.00014737286 -3.6386944f-5 8.425154f-5 3.7085f-5 7.295013f-5 4.7059257f-5 -5.175591f-6 6.233437f-7 4.4535816f-5 9.172302f-5 1.1437286f-5 -0.00012517066 -6.582043f-5 7.168568f-5 7.067142f-5; -6.3352986f-6 -0.00010558855 6.301901f-5 -7.2308096f-5 0.0002235429 4.714376f-5 9.96508f-5 9.4115145f-5 2.8260856f-5 3.2967055f-5 -9.730319f-5 3.1836098f-5 1.9537278f-5 -9.9328994f-5 -3.7119007f-5 0.00011752004 6.209501f-5 3.277808f-5 0.000113662005 9.6462194f-5 0.000113457274 5.6377543f-5 -9.212928f-6 -0.00018711254 -3.4428944f-5 -5.0626717f-5 -5.5463083f-6 0.00016302538 0.00011083861 2.7477174f-5 -3.1925825f-5 -4.5240944f-5; 0.0001152834 3.0216572f-6 -8.832308f-5 -5.6151122f-5 -0.000121704325 2.3036868f-5 0.0001426623 -4.0847787f-8 0.00015484967 6.0968585f-5 -9.809945f-5 7.835527f-5 7.555767f-5 -0.0001461892 9.968427f-5 -1.516975f-5 0.000124515 -9.583998f-5 -0.00010052209 3.7906375f-5 0.0001906334 7.034836f-5 0.00011445278 0.00010071716 -7.172525f-5 
-0.00017351462 -7.79522f-5 4.272648f-5 0.00012029635 -3.143231f-5 -9.412922f-6 1.1290461f-5; 0.00012237388 0.00010136952 -8.14831f-6 -5.855067f-5 -0.00010960555 4.5055654f-5 1.6882883f-5 -0.00014892346 0.000101698715 7.738465f-5 -0.00016184244 0.00034254382 0.00011896214 6.629721f-5 -0.00013621809 -6.47582f-6 0.000100204896 -0.00012981275 5.703502f-5 -0.00010633344 1.4006341f-5 0.00013085187 -1.3825761f-5 0.00016771269 -3.5474495f-6 4.9671282f-5 0.00019055384 -0.00010915678 -7.717513f-5 -0.00013567558 -3.6553956f-5 8.208136f-5; 3.1407748f-5 -3.5949084f-5 0.00010137459 2.7060556f-5 2.7053102f-7 0.00015410015 -5.1578412f-5 6.204385f-5 3.193699f-5 6.4502456f-5 0.00012603962 0.00010126935 -8.362082f-5 7.7487784f-5 -4.8873044f-6 2.5337555f-5 -5.1445375f-5 -5.1513844f-6 -2.767903f-6 4.465061f-5 -1.8254339f-5 -0.00015734015 -8.700112f-6 -0.00011558234 -4.8127687f-5 -9.654473f-5 4.666573f-6 -8.551778f-5 0.000104710845 7.47412f-5 3.622029f-5 0.000107403364; 8.32235f-5 -1.4476366f-5 -2.7699438f-5 0.00013069861 4.4145985f-5 7.961726f-5 0.00013396857 -6.157701f-5 0.00017729012 7.092053f-5 5.2188025f-5 7.2532304f-5 -3.7055375f-5 -7.559554f-5 1.228008f-5 8.836619f-5 4.2754633f-5 8.8080196f-5 -9.59113f-6 -2.0220405f-5 -0.00013497862 9.294577f-5 -0.00011399219 4.583301f-5 -4.547503f-5 0.000101952515 -9.606876f-5 -7.177209f-5 -9.1942224f-5 -1.488631f-5 -0.0002159238 -1.8467716f-6; -0.00017403172 0.00012039652 4.880822f-5 0.00011095257 0.00017168066 4.6258683f-5 -7.2291994f-5 0.00022417033 4.5751345f-5 -0.00011648577 -3.7656508f-5 -0.00012891632 2.9957315f-5 -3.3848653f-5 1.4531099f-5 -2.4462002f-5 -5.4282613f-5 -0.00019622898 -7.4768854f-5 3.0503092f-5 -0.0001486807 3.882637f-5 -0.00019933355 -5.102829f-5 -4.1463387f-5 -6.864861f-5 5.7127418f-6 -0.00011453286 0.000101088706 -7.536747f-5 7.267258f-6 0.00014439737; 0.00017717501 6.9591275f-5 4.3690226f-5 -0.0001400232 -2.3636607f-5 -0.00015131135 -6.68716f-5 6.0007087f-6 -0.00023090505 -0.00022690075 -0.00012166241 -0.00025368747 
-0.00016141379 -0.00014116596 2.343849f-5 1.9010182f-5 -0.0001813586 0.00014517858 -7.057667f-5 -0.00019251379 2.9176106f-5 -0.00016993308 -4.2865668f-5 -4.4911667f-5 0.00014493571 -5.7684592f-5 -5.3412605f-5 -0.00017383507 -0.00011772734 -3.259214f-5 7.958008f-5 0.00010368264; -7.681433f-6 -8.909701f-5 1.8553683f-5 9.269079f-5 6.572268f-5 8.250329f-5 0.00015366315 0.00021934262 0.00012938718 -0.00021552155 6.0171533f-5 -2.7473574f-5 2.7159493f-5 -7.441063f-5 -8.079247f-5 -2.6589118f-5 -1.2885892f-5 5.925725f-5 -0.00013622036 -7.801469f-5 2.1818933f-5 -4.618377f-5 7.5390103f-6 3.8665326f-6 0.00013373958 -3.9614395f-5 -0.00013069184 8.768247f-5 0.0001299382 0.00012238494 -7.740815f-5 0.00010868511; -8.673972f-5 0.00018752282 3.9684055f-5 -9.687451f-6 5.2453932f-5 -9.27225f-5 7.073892f-6 0.00011920174 -6.593175f-6 -0.0001621245 -8.183615f-5 5.9841892f-5 -2.1354075f-5 -8.209071f-5 0.00013554495 0.00010121475 0.000206503 0.00015145575 -0.00027937372 -2.20844f-5 -0.000109768655 5.691708f-5 0.00014290072 -3.6574143f-5 -8.551349f-5 -9.2877235f-5 6.537012f-6 0.00019176933 -0.00010393473 -5.165483f-5 9.45943f-6 2.4477444f-5; 1.581303f-5 -0.00023796853 -2.8839599f-5 6.94895f-6 -7.5664015f-5 8.85389f-6 8.748957f-5 0.0001138342 9.507335f-6 1.2049911f-5 5.117918f-5 0.00017760215 7.834263f-5 -5.534891f-6 -5.9675873f-5 3.1729442f-6 4.2789423f-5 -9.530759f-5 8.265047f-5 2.1526188f-5 5.983514f-5 -0.00014385596 3.4952074f-5 -4.167226f-5 5.7575264f-5 0.00010817661 -3.292535f-5 0.0001046927 7.866234f-5 -9.1739035f-5 0.00015456064 3.1810512f-5; -0.00021829562 3.4530294f-5 -2.6639735f-5 -6.76679f-5 3.2383276f-5 -0.00012216858 2.3017199f-5 -8.155453f-5 -8.358408f-5 -4.0715902f-5 0.00011267719 4.160624f-5 0.000109239714 -7.1800474f-5 -5.217137f-6 5.4525635f-6 -3.9479604f-5 -0.000113431 -9.935221f-6 7.3549156f-5 2.8629258f-5 4.767458f-5 8.8313616f-5 -0.00010601123 -0.00029482867 -0.000103817925 -8.411555f-5 2.7850263f-5 -3.6905298f-5 -5.1553958f-5 -6.8098375f-6 2.8414906f-5; 9.957256f-5 
-7.702982f-6 0.000121408004 0.00010109259 0.00026289525 0.00019780677 -0.00015470207 -0.0001521882 0.00012445087 0.0001822352 0.00013588971 -8.4551735f-5 -0.00012137907 3.438123f-5 0.00016546623 -0.00018445539 -9.9342404f-5 -1.32023715f-5 5.6854413f-5 -5.594594f-5 -0.00011769705 -6.716253f-5 5.3415086f-5 0.00017330503 -0.00012794403 -6.0168095f-5 -2.1766884f-5 -6.698659f-5 0.00015472858 7.743228f-5 2.6077969f-5 0.000115425806; -1.888069f-5 -0.00014460596 0.00013395856 0.00017908281 2.5012558f-5 -6.0914026f-5 3.403034f-5 -5.1980995f-5 1.2573065f-5 -7.3670824f-5 -0.00013787678 -4.973043f-5 6.9624076f-7 -9.427746f-5 -5.52234f-5 4.8285856f-5 -2.0195885f-5 -2.3948856f-5 2.0601661f-5 0.00011449763 -2.3012751f-6 -0.00012805217 -0.00019656036 0.00013398578 6.929642f-5 5.6220808f-5 -8.657534f-5 0.0002076661 6.7012654f-5 -0.00018631309 7.831518f-5 4.9613045f-5; -7.215539f-5 0.00012338528 -2.3958758f-6 4.7523215f-5 -4.5945126f-5 0.0001534441 7.6007804f-5 0.00015268593 0.00012554169 9.584925f-5 -0.0001378673 -0.0001753667 -2.2997961f-5 0.00013416598 5.7550637f-6 -0.00019873622 7.38176f-5 -0.00022811533 -0.0001523798 -0.00012318887 9.173681f-5 0.00013112393 6.0601214f-5 8.591715f-5 8.297477f-5 0.00016363552 -4.1867337f-5 3.739638f-5 5.3894877f-5 -0.00013752244 -6.617077f-5 -0.00016490417; 0.0001367298 -3.261648f-5 4.4791075f-5 2.0726134f-6 -9.4118055f-5 4.3951954f-5 8.718649f-6 5.134851f-5 9.0021626f-5 -2.0560672f-5 4.491345f-6 -0.00016216288 -1.1456702f-5 -0.00019506359 0.00014525976 5.9188988f-5 -8.182869f-5 0.00024369085 -5.5828008f-5 1.7089546f-5 2.4843517f-5 0.000104218794 9.777083f-5 -6.188139f-5 -5.5938042f-5 -0.00012998325 3.3918244f-5 0.00014793669 9.1900394f-5 -8.738238f-5 1.542543f-5 -0.00017161651; 0.00015695009 0.000117184696 -0.0001572587 5.177981f-5 -8.019283f-5 -9.815502f-6 -2.4400922f-5 -9.508648f-5 4.2062133f-5 -0.00017471463 -3.270858f-5 9.908271f-6 -5.3776228f-5 5.558832f-5 4.466952f-6 -0.00013568826 -2.3291566f-5 7.570693f-5 -8.644662f-5 0.00022712519 
5.2152314f-5 -1.6636164f-5 -4.794384f-6 4.068036f-5 -6.5889784f-5 -8.004949f-5 -2.9669507f-5 0.00013352981 -0.00015436686 0.00011167081 -1.4654411f-5 -0.00015323452; -2.9461826f-5 6.8590925f-5 -0.00012171094 -0.00011836626 -0.00010853122 0.00014201392 0.00020106377 -6.2380794f-5 6.357512f-5 -3.1155827f-7 9.301685f-5 9.9642115f-5 -7.138964f-5 5.0885815f-6 6.454082f-5 9.1660804f-5 -5.082537f-5 6.190584f-6 -6.0650917f-5 -3.3009714f-5 -3.3855427f-5 -5.0583512f-5 0.00015465172 -1.686994f-5 0.00010381959 -0.00020086455 -0.00015882874 -5.214386f-5 0.00014969076 4.3677293f-5 -9.406667f-5 -6.4827924f-5], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[3.018463f-5 0.00011127962 -0.00011255401 -2.1164975f-5 -0.0001900768 8.630377f-5 4.5715995f-5 -3.966679f-5 6.708387f-5 0.00014986689 -0.00015763774 9.594826f-5 0.0001307211 -5.668102f-6 -0.00014179894 -1.7131359f-5 5.820793f-6 8.9608824f-5 -2.2573895f-5 -3.0693416f-5 -0.00018939441 0.0001693484 0.00011408394 -3.0475076f-5 6.7545225f-5 -0.0001409898 0.000100346355 -8.661279f-5 -0.00010456137 -0.000113500224 6.918368f-6 0.0001023217; -5.006268f-5 8.206729f-6 0.00018038848 -0.00016183664 -6.433965f-5 -8.572355f-5 0.00015852605 0.00022077427 -9.7935235f-6 0.00012103822 0.00012834302 -0.000117005104 7.237987f-5 1.697464f-5 9.7772085f-5 7.940262f-6 0.00015429372 -5.862251f-5 -1.2826915f-6 -0.00017605376 -0.00010474867 -0.000114600225 1.9295674f-5 -6.177461f-5 -7.8105746f-5 -5.140639f-5 6.9116264f-5 -7.929815f-6 -0.00017014152 0.0001295078 -1.8335195f-5 0.000111846624], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we need Float64.

julia
# Lux creates Float32 parameters; convert them to a flat Float64 ComponentArray.
const params = ComponentArray{Float64}(ps)

# Bundle the network with its state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(broadcast), typeof(cos)}(broadcast, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)

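Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, corrected by the neural network $\mathcal{N}$, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants:

$$\dot\chi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot\phi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_2(\chi)\right).$$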

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural-network-corrected Newtonian orbit model, with `u = [χ, ϕ]`.

Both derivatives share the Newtonian rate `(1 + e cos χ)^2 / (M p^(3/2))`. Each is
multiplied by `1 + nn_model([χ], nn_params)[i]`. The constants `p, M, e` are read from
the global `ode_model_params`; `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example `st` is an empty NamedTuple, so it can safely be ignored; in
    # general, `st` stores the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [rate * correction[1], rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained parameters, using the same solver settings
# as the true model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # One legend entry per (line, markers) pair, placed at the left-bottom of the axis.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
"""
    loss(θ)

Simulate the neural ODE with parameters `θ` and compare it with the global reference
`waveform`.

Returns `(sum of squared errors, predicted h₊ waveform)`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    pred = Array(sol)
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    sse = sum(abs2, waveform .- pred_waveform)
    return sse, pred_waveform
end
loss (generic function with 1 method)

Warm up the loss function (the first call triggers compilation):

julia
# First call compiles the simulation/waveform pipeline; returns (loss, predicted waveform).
loss(params)
(0.1800989917937656, [-0.024266635907772467, -0.02348160893779232, -0.022696581967811782, -0.021370139252574635, -0.01947452855671806, -0.016969465224864737, -0.013802057021505183, -0.009903839424664489, -0.005193045022400174, 0.0004257415877671567, 0.0070523547394426905, 0.014770033727875017, 0.0235976374040273, 0.03336078087957971, 0.043412364469584463, 0.05197169880880543, 0.054726766210495346, 0.0425591565768829, 0.001908910472275266, -0.06644690957714286, -0.1103648429386085, -0.07616831300438578, -0.006708664474369053, 0.03885273253635518, 0.05425277135560105, 0.052908443171484604, 0.0448013346746762, 0.034791601848077475, 0.024913378351941486, 0.01592472182586159, 0.00804049267846761, 0.0012603594368133884, -0.004495495925738982, -0.009327264456562583, -0.013332395608645091, -0.016594409168042827, -0.019184334623002938, -0.021157648563935132, -0.02255643820074943, -0.02341028292099017, -0.023737219469271684, -0.023544085401447163, -0.022827779898846913, -0.021572935411503538, -0.019754000865682297, -0.017332899085075185, -0.014258243491913854, -0.010464757115993924, -0.005872468537071841, -0.0003882284223495848, 0.0060868070769522584, 0.013641211429992429, 0.02230466226318443, 0.031943326242482924, 0.042004118635398636, 0.05093285490951226, 0.054911011157182506, 0.045661299600148544, 0.00983414617028628, -0.05627172378405606, -0.10849841270551336, -0.08515661241174532, -0.015961204285368357, 0.0344981160468159, 0.053451069643968945, 0.05372378317134975, 0.046163735139424386, 0.03623252680420606, 0.026251697508210863, 0.017104909103774442, 0.009051459084935654, 0.0021174055231829823, -0.0037810663134520833, -0.008734004380168248, -0.012849587682943588, -0.016206181389440813, -0.018884129155451485, -0.020934757817838854, -0.02240756749493573, -0.023330002781669346, -0.023724441027128833, -0.023598415384367886, -0.02294949653806446, -0.021765949744339902, -0.020023756631157477, -0.017683652417167715, -0.014702259134047652, -0.011009406052945536, 
-0.006535392660962897, -0.0011810857093720922, 0.005143605761667868, 0.012537628306840222, 0.021034840362932083, 0.030541998317513776, 0.04058260443899007, 0.049808510217107534, 0.05484274613789371, 0.04820821003566249, 0.017028278394606675, -0.04590171176348248, -0.10483056725550938, -0.0931482648376222, -0.025741254404039366, 0.029452063575540616, 0.05227844985252703, 0.0543967195307308, 0.04749028943714608, 0.03767982095193669, 0.02761406040921459, 0.018307380667002347, 0.010090119119682934, 0.0029921613416898315, -0.003044054580402047, -0.008127757997000874, -0.012349804960067094, -0.015808318654118303, -0.018570673290781343, -0.020703710708983896, -0.02224819446651297, -0.02324170437135099, -0.023702922437451857, -0.02364356299489565, -0.023063341295450287, -0.021949386133658755, -0.020282451152855215, -0.0180238013785514, -0.015132328119470543, -0.011540119746535205, -0.007180260925740662, -0.0019549970228576463, 0.004223710383109179, 0.011457608113356322, 0.0197903921949235, 0.029156581188324364, 0.0391552571824331, 0.04861352860587331, 0.054551867374327594, 0.050252754688926614, 0.023473139070676503, -0.0355608558719836, -0.09950608950864288, -0.09987656134985903, -0.03591731144972561, 0.023679367391418536, 0.05069250960176209, 0.05490201537759084, 0.04877117923356294, 0.03913216278496565, 0.028992989710465725, 0.019540224715818684, 0.011149444990871827, 0.003892807943244369, -0.0022939223037245684, -0.007499184148808834, -0.011838777048042462, -0.015395346147242344, -0.0182497882106674, -0.020459318449161903, -0.022081009622849188, -0.023144013708647205, -0.023673047076773063, -0.02368047674958848, -0.023167593926624902, -0.022124405063867085, -0.020528907130179935, -0.018354929013843783, -0.015548771489225678, -0.012055321563114656, -0.007809214430667372, -0.0027109657647002077, 0.003327014687523413, 0.010403120567440773, 0.01856853849358103, 0.027789945355195616, 0.03772644971076383, 0.04736020402588669, 0.05407148149954519, 0.05183855324367573, 
0.029184493965986165, -0.02546464519791502, -0.092705418828434, -0.1051126149071204, -0.04632271631883849, 0.01716493836606892, 0.04863864298820128, 0.055216146732823164, 0.049994587614765024, 0.04058114957573975, 0.030397814428874743, 0.020793158345382833, 0.012235108651876507, 0.004812070071704662, -0.0015159202040639578, -0.006859307228194212, -0.011312062191554591, -0.014971862270201275, -0.017912101402068635, -0.02020838190315152, -0.021903792882912616, -0.02303779296133484, -0.02363402470669566, -0.023708923999870395, -0.023263163626901093, -0.022288254413196016, -0.020769954778543296, -0.018671373360668892, -0.0159536572236072, -0.012557273080792164, -0.008420745610541118, -0.0034468548144575128, 0.0024510272284665614, 0.009371024954302485, 0.017372986762272304, 0.026443412664315016, 0.03629954579126344, 0.04606101665099091, 0.05342516057265422, 0.05301523538378648, 0.03417884303164089, -0.015768709071722124, -0.08467055246249218, -0.10867227671402216, -0.05674138112797835, 0.009892576447938199, 0.0460753626329571, 0.055305838449893216, 0.05114728795974092, 0.04202641088900143, 0.031816834894986795, 0.022073243751881452, 0.01334558312841371, 0.005756553273615908, -0.0007223551069799499, -0.006199943388980478, -0.010770133827483992, -0.014533763794684247, -0.017566066750704424, -0.019945448449909427, -0.021717087122574138, -0.022922425859276698, -0.023586796947211983, -0.023728607585349722, -0.023349986573678472, -0.02244420694510988, -0.020998510594165045, -0.018978248945500764, -0.0163462153823128, -0.013045300774588507, -0.009016294980866639, -0.004987289187144917])

Now let us define a callback function to store the loss over time

julia
# Loss history, appended to once per optimiser callback invocation.
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Record the loss `l` in `losses` and print a progress line. Always returns `false`, i.e.
it never asks the optimiser to stop early. `θ` and `pred_waveform` are unused.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf("Training %10s Iteration: %5d %10s Loss: %.10f\n", "", iteration, "", l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Gradients of the objective come from Zygote (via Optimization.jl's AutoZygote backend).
adtype = Optimization.AutoZygote()
# The objective only depends on the trainable parameters `x`; the extra argument `p`
# required by the OptimizationFunction signature is ignored.
optf = Optimization.OptimizationFunction(adtype) do x, p
    return loss(x)
end
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations;
# `callback` logs the loss after every iteration.
res = Optimization.solve(optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-6.215363100633597e-5; -4.783957410835457e-5; 0.00015146200894353983; -0.0001513369934399253; 3.500078673818362e-5; 2.7854426662050848e-5; -0.00014940995606584718; 3.6730096326209456e-5; 8.703844650860685e-5; -8.741318015365537e-5; -5.554452218343429e-5; 0.00012893964594655884; 6.930594827278274e-5; 0.0002903089334714698; -5.028745727025768e-5; -8.838623762118448e-5; -6.353266508069379e-5; 0.00012454047100614457; -0.0001700225257079893; -5.676309592660566e-5; 7.16408103470753e-5; -3.1511647193796122e-6; 6.22018633293096e-5; -3.510165697658078e-5; 8.41098517412925e-6; -4.9338970711618536e-5; -3.577567258612491e-5; -6.336014484981907e-5; -0.00010776172712206163; 3.5219680285035135e-5; -9.283920371661393e-5; 0.0001627835590623408;;], bias = [-3.540268420568159e-17; -1.1511440219465359e-17; 2.658981402864399e-16; -1.4696527697388972e-16; 5.932310763061899e-17; 4.221115712139736e-17; -1.0304097838067351e-16; -1.12354476823274e-19; 2.2974216170065214e-16; -1.8299858061320895e-16; -5.2273653055001314e-17; 4.452913787414255e-16; 3.6094215978585734e-17; 8.740525560584779e-16; -6.247461837226484e-17; -1.375086262170008e-16; -1.8244113501738682e-16; 4.3698721717608095e-17; -2.230669748027857e-16; -1.3251546221905893e-16; 9.492511672277338e-17; -4.6003067601167516e-18; 9.107740279205176e-17; -4.515408158785387e-17; -6.459040355194215e-18; -7.749526566283796e-17; -4.933117277904469e-17; -5.32449060761712e-17; -6.855931848044483e-17; 5.052995939957442e-17; -1.1950332518210313e-16; 1.9218685880051537e-16;;]), layer_3 = (weight = [5.8262058355234286e-5 0.00011413046251595367 3.061391418645484e-5 -0.00010710452386798012 3.895392967000952e-5 -6.522343517668198e-5 0.0001432338491392727 1.0009788293013095e-5 -0.0001247454760052788 0.00014385055930665107 -0.0001011315665712173 -0.00013292488927641548 6.915768119119452e-5 0.00013105664640510718 -9.601979161766942e-5 6.562091095451606e-5 0.00013870925124114726 
-0.000146291623769008 -0.0001598379143436035 6.579411512552175e-5 0.0002242454526092333 7.880579016249975e-5 0.00014302291912803773 -2.0618464423256242e-5 2.0110791065002555e-5 9.9822497029381e-6 -9.670664929240594e-5 -4.46296484190098e-5 -0.00024201461690632288 -3.9323288893006986e-5 -6.369606615431267e-5 -3.734161999114467e-5; -0.00012751989593484659 -6.69716394924297e-5 9.075253878731518e-5 -1.5288788156935455e-5 -3.141893645173665e-8 -1.9410407142593987e-5 -3.4234697844253207e-5 0.00012610156377090915 3.918117405964682e-6 0.00013250915039575344 -4.944964562414412e-5 4.122110444481878e-5 8.438638146512902e-5 4.999044872131755e-5 -3.818832588253727e-5 0.0001387856242643917 2.581230696419032e-5 4.583591330342417e-5 5.428610861312683e-5 6.056049428193629e-5 -0.0001511838275775059 -3.75266030033783e-5 -7.389426205295355e-5 -4.438955179752932e-5 4.633970424661291e-5 -3.414775378873753e-5 -0.0001650176639970128 -4.0732153459566054e-5 7.548602265783174e-5 -6.694428916775772e-5 -5.740577174174707e-6 0.00018908762253599387; -2.5282507913899747e-5 -2.0005126708084796e-5 1.2792288152311562e-5 0.00010972882097969679 -5.994978840196508e-5 -4.505560746721478e-6 -0.0001268388490922664 1.4096156127678653e-5 4.266906734599927e-6 0.00010905840696916888 3.550300477338944e-5 7.323114784351489e-5 -4.735668829259479e-5 -7.590633088573187e-5 0.00012224277686011972 7.557533769552568e-5 -2.5715210932204642e-5 2.3440658487132188e-5 2.8915210868333347e-5 -0.00010313096338965941 -0.00012895292223281244 -9.94102933242824e-5 -5.651029044263078e-5 0.0001445505427630711 9.681112948722206e-5 -0.0001453408962612879 -2.436906010383644e-7 -2.3949528832119097e-5 -2.2326364551801648e-5 -7.491826311768037e-5 6.446762069511196e-5 0.00015493154439293681; 9.311181548198093e-5 -4.193391959618529e-5 1.7367625155914626e-5 -7.909269715511702e-5 -0.00011723090228245405 -4.24667488862092e-5 2.9497328619784378e-5 -0.00010353222405158718 -0.00017182650862580554 -6.708401287526276e-5 -4.680536398363158e-6 
-0.00017663863781692342 -6.456836508207566e-5 -2.419144602224993e-5 -0.00017915684674720016 2.4363238304225718e-5 -9.260113819821392e-5 0.00010917725358538384 0.0001764016085924475 -3.561950160858853e-5 -9.475330097687022e-5 -6.72775678997195e-5 -5.6129360195180754e-5 9.662704160816225e-5 -4.131963232269385e-5 5.320283508423972e-5 7.424725100078813e-5 2.8989090247497622e-5 0.00017399467815792027 0.00010403836841375757 9.145587506076757e-6 3.242128860236814e-5; 2.7210907505147122e-5 -2.2597573475426998e-5 -9.013572598245588e-5 -7.823265950332057e-5 -0.00010256159184405549 -3.8574437708808035e-5 -6.091145296135616e-5 0.00018089893383534664 0.00020218239042537679 -0.00011970158702365656 -0.00010776857270298387 -9.510935960474446e-5 -0.00030448578654227544 -7.619230636528589e-5 -2.165351979397596e-5 -0.00012207531726358053 8.58312134407666e-5 0.00012818430077846075 -3.352686580523171e-5 0.0001996197981536589 2.552014048364734e-5 4.272080441682562e-5 -8.858387336272565e-5 -3.756933328600567e-5 5.0078078133305195e-5 -5.273294554695386e-5 3.170272251959313e-6 -8.324285586926107e-6 -0.00010458804790300626 0.00014829654647132413 -8.574246639523203e-5 9.940871059758195e-5; -1.536835319642548e-5 1.604469340305741e-5 0.00032239145090057024 -0.00011894664479700413 0.00012649959258920993 0.00013720823059044346 0.0001373425011122459 5.4518327131181344e-6 2.7815424765103564e-5 -5.269935267209816e-5 9.078150920755626e-5 -8.762158089597923e-5 3.463804110721072e-5 -4.228364472444629e-5 2.7536786511302476e-5 0.00010974480843356994 0.0001666671568825783 5.1008105366014176e-5 -0.00011682652533568275 8.877317028305881e-5 -0.0001725004431493643 -0.00013902347112819955 -1.5055814437108236e-5 0.00013833658609725084 -7.90011354540515e-5 9.93617113279128e-5 3.6000147666311925e-5 -3.155837388407085e-5 -0.00013258362290729755 2.2094977758354633e-5 -0.00017210902573355327 -0.00011418792113477089; -1.5757447094554885e-5 -0.00014245269476717185 -0.00011999746703122355 0.00015318653249369387 
-7.756931363886224e-5 4.066150513196298e-5 3.450983576431236e-5 -0.000168843262421917 -0.00030031735723652875 0.00016728687268867165 7.234415324589508e-5 0.0001923363940436432 0.00015907208371161076 -0.0002609938331301224 -7.085797761149515e-5 8.208281150971571e-5 0.00015773739659877904 -4.5392584347163825e-5 -5.692922996285041e-5 -8.566000388284854e-5 8.370599761790236e-5 -0.0002120396186951752 -0.00015068989369367792 -5.761389393640682e-5 -9.073080418914812e-7 -0.00012966458807980467 -0.00017489793599864461 -9.473015147275873e-6 0.00018165099545102223 -0.0001423873712197132 -0.00016047043964546954 -9.012682340152582e-5; 4.6963189373435306e-5 -9.188707481958896e-5 0.00019790995865472287 0.00014192130474251392 3.556153823732424e-5 8.804673124681259e-5 0.00017009093134541699 -4.699584411774757e-6 0.00016462895727594316 2.1995126001945916e-5 0.000140220258611897 -7.14177469548883e-5 -0.0001435100886796697 1.3236625630747783e-5 3.0881787819033546e-5 -0.0002260949288087997 -3.959759584596782e-6 -9.032597445772218e-5 -9.120974864933102e-5 -7.153934275853453e-5 -5.848457222460803e-5 -0.00011255987483897202 -7.559694786910914e-5 9.903642493876654e-5 -9.657711344194556e-5 0.00011871401219972416 -6.007598058791807e-5 7.384376366966387e-5 6.479271978017744e-5 3.485141932609516e-5 -8.858905058010567e-5 -3.597584801033496e-5; -1.246974355968801e-5 4.936960245753678e-5 8.33489429721451e-5 -6.706974458146549e-5 1.170010815325597e-5 0.0001256444993130047 -9.346893424598139e-5 -5.225739372622996e-5 -2.5587703966509157e-5 5.216634959718523e-5 3.7846981250130916e-5 0.00014268453011092861 -6.478507206271674e-5 1.8957653047099522e-5 -0.00011627927964515044 7.305395771321198e-5 0.00024174971982641902 4.2480234661272645e-5 0.0001578983246515274 0.00020870105432617403 8.357242186990714e-6 9.733814291461694e-5 1.0654487560476023e-5 -0.00010151601967596279 -8.331319626382237e-5 0.00010319830831644147 3.9193314121272565e-6 -2.717545798010366e-5 -3.484102571085142e-6 1.1056375988793389e-5 
0.0001735890873176065 -4.663370058487323e-5; -5.433876987214253e-5 -9.547151177090946e-5 -8.29586618360054e-5 3.7692958982529404e-5 -2.4746568045402833e-5 2.4849058433145406e-5 -1.4500390819581579e-5 8.510548167262628e-5 5.557592038398331e-5 0.00011929204873608159 -5.056295532400922e-5 -0.0002486596616230947 8.449430850899902e-5 -0.0001717239186412342 -0.00020364486891909808 -0.00010821866936881328 -7.145144383090568e-5 -0.00010407677956522143 -3.501041402470269e-5 -0.00015017931338067342 -0.00032936056715103605 -9.372707910504048e-5 -5.9428200359857226e-5 -7.80261737540699e-5 0.00014685812010295843 -2.6259048639501898e-5 5.0280391120877256e-5 2.5812540740421725e-5 -8.315931091915328e-5 -4.295939771825068e-5 0.00011429330578407193 -8.341301628517393e-5; 2.4060362727531225e-5 -4.346390361440495e-5 -2.6054818056347064e-5 -8.676265807711545e-5 -6.837096535725796e-5 0.00010480217368915337 -5.324802933732156e-5 -0.0001359440424608465 5.454153923659268e-5 -3.644185509307618e-5 -1.155253395388385e-5 0.00010966522003168264 -9.805547543925941e-5 -0.0001456321691989132 -4.981666589862762e-5 -0.00021742376649177915 -9.723472559250775e-5 -6.63711756829372e-5 0.00011041228625569379 1.1799436458593727e-5 8.573589783503014e-5 7.040924253872535e-6 0.00013185515907701198 -0.00020990129943363644 -7.768491869917579e-5 -5.366589122111085e-5 -8.569745788242936e-5 3.824817812294814e-5 0.00015549969301668248 -3.342429355377608e-5 -0.0002229350570019631 -0.00014477191817835426; -0.00010452487241698687 0.0001088289931301181 1.2146763206810273e-5 -1.6984161097089263e-5 0.00014277984099473427 -0.00013419006483152902 -4.112673251645057e-5 4.061353264270911e-5 -2.931282198394737e-5 0.0001866073283833629 2.721908584655436e-5 -7.824179976002922e-5 -0.00010642456674238881 7.74483336370033e-5 0.00012344015648133738 -5.77380580059796e-6 6.823492675332104e-5 0.00013064838948205387 1.183125676628385e-5 -8.91781315788469e-5 -4.52900609291458e-5 5.405992431926109e-5 -0.00011715600759478544 
8.264290171220523e-5 7.918655991099345e-5 -0.00015576228962734662 -0.00019170406504989045 1.726632525398612e-5 -3.5848116216835546e-5 -0.00010218999579072374 -1.654522076412611e-5 -0.00024388302755530154; 2.246567134980008e-5 3.3344523103458436e-5 -9.364261961771703e-5 4.606966729162661e-5 -4.8062212158677414e-5 6.988652864938265e-5 6.555377592568641e-5 -2.709075829536922e-6 0.00023464923549445523 -0.00015924493084193956 -0.000156014944082109 8.542621333154777e-5 3.087842821960304e-5 -9.902287740718197e-5 4.0519155731906926e-5 -0.00010118553939273388 5.6322921295668024e-5 -2.3968917162247647e-5 -0.00010306589422278184 -4.692302275098081e-5 -7.520070823838944e-5 -7.140505942980221e-5 -2.676613178218819e-5 0.00010827407356818835 -7.640554861164111e-5 4.248925949914169e-5 1.4349802654485014e-5 -5.053547386078257e-5 -3.9255433080228834e-5 -0.00015071400150655137 -8.476360847007229e-5 -2.430269444183241e-5; 4.950439528341256e-5 0.00013660095850367182 -5.8880070137001396e-5 0.00016169137074039153 9.403831631073966e-5 -0.00020231195390794116 -9.198241765252392e-5 -0.0001933459514110513 1.3524283749689365e-5 0.00010500511984003002 -2.9813301589335652e-5 0.00017060705271443183 7.697247407395091e-7 8.912729230360222e-5 7.902665514860348e-5 1.540663310163388e-5 0.00015565600755519988 9.745953706452618e-5 0.00013882587358247372 0.00011191285579128135 9.046595936536453e-5 -0.00017703274512301712 5.640613584558329e-5 -3.493942655028536e-5 2.832164886070826e-5 -3.422662539256817e-6 9.000245176151624e-5 -4.1230907081507e-5 -4.560741013836477e-5 0.00011699821174273503 0.00014778349417622283 8.907900904895972e-5; -1.0870773093383235e-5 -2.8964501953750158e-6 -6.494548089821796e-5 0.00010310958949130986 -0.00020169973087748856 3.790563694648111e-5 0.00012688517797951633 -7.029825011985725e-5 -0.00011026951635275543 0.00010602134769297626 2.6998767299038154e-5 1.8322580745832095e-5 2.45788329092742e-5 -0.00013638337206349399 -4.610286434316455e-5 0.00015636107180481263 
9.770703816767366e-5 -0.00014737207622519475 -3.638615764347963e-5 8.425232279561065e-5 3.7085785680492604e-5 7.29509198035082e-5 4.70600430898327e-5 -5.1748046817816415e-6 6.241300124308786e-7 4.453660274603959e-5 9.17238072187605e-5 1.143807237977909e-5 -0.00012516986991530896 -6.581964354981083e-5 7.168646750952514e-5 7.067220629406097e-5; -6.332356016262922e-6 -0.00010558560877243214 6.302195418255142e-5 -7.230515309049074e-5 0.00022354584091672624 4.7146702335253275e-5 9.965374504543635e-5 9.411808735322141e-5 2.826379880288888e-5 3.296999732710688e-5 -9.730025028247616e-5 3.183904066544365e-5 1.954022029976897e-5 -9.932605150532513e-5 -3.711606456917888e-5 0.0001175229803528941 6.209795122131383e-5 3.278102251796497e-5 0.00011366494748369887 9.64651366565725e-5 0.00011346021658837372 5.638048553819831e-5 -9.209984974494e-6 -0.00018710959972949748 -3.442600115803945e-5 -5.062377437852491e-5 -5.543365724493631e-6 0.00016302832539522702 0.00011084155578727718 2.7480116322207655e-5 -3.192288256387055e-5 -4.5238001413165107e-5; 0.00011528527266038455 3.0235287808140717e-6 -8.832121074973965e-5 -5.614925084827823e-5 -0.0001217024529965591 2.303873934708597e-5 0.0001426641772935973 -3.8976227885000624e-8 0.00015485153726460058 6.097045643529498e-5 -9.809758030185555e-5 7.835713931982432e-5 7.555954089313562e-5 -0.00014618732366091942 9.968614340452433e-5 -1.5167878881205288e-5 0.0001245168767140233 -9.583811170031221e-5 -0.0001005202195288779 3.790824648733163e-5 0.00019063526433585004 7.035023172777254e-5 0.00011445464933914937 0.00010071903107147698 -7.172337627578423e-5 -0.0001735127500056293 -7.79503299769011e-5 4.272835012799924e-5 0.00012029821828165423 -3.1430439956765276e-5 -9.411050267393675e-6 1.129233242082067e-5; 0.00012237605343115454 0.00010137169529925147 -8.146136943763171e-6 -5.8548497955710665e-5 -0.00010960337594142994 4.505782790956239e-5 1.6885056273965548e-5 -0.0001489212830084458 0.00010170088872554387 7.738682116622706e-5 
-0.0001618402696213779 0.00034254599241547253 0.00011896431325011757 6.629938225963039e-5 -0.00013621591851106812 -6.473646673031538e-6 0.00010020706914373595 -0.00012981057288117656 5.7037193718199105e-5 -0.00010633126866296886 1.40085146166625e-5 0.00013085404289894944 -1.3823587588475266e-5 0.0001677148590796546 -3.5452760958884244e-6 4.967345531339793e-5 0.00019055601088338646 -0.00010915460942770431 -7.717295583562602e-5 -0.00013567340855943853 -3.655178287873443e-5 8.208353184170301e-5; 3.1409351256520564e-5 -3.5947480748795014e-5 0.00010137619658729286 2.7062159394912826e-5 2.7213439120403365e-7 0.00015410175813253485 -5.157680886674417e-5 6.20454547310569e-5 3.193859349943791e-5 6.45040590165471e-5 0.0001260412272380977 0.00010127094986040699 -8.361921670489707e-5 7.748938781096425e-5 -4.885701012021097e-6 2.533915806219994e-5 -5.1443771619738104e-5 -5.14978098413994e-6 -2.76629959677781e-6 4.465221257073194e-5 -1.825273563786655e-5 -0.00015733854986370053 -8.698508347656511e-6 -0.00011558073772226037 -4.8126084104799486e-5 -9.654312914113153e-5 4.6681765774761025e-6 -8.551617527023611e-5 0.00010471244870808791 7.474280112314461e-5 3.622189517722505e-5 0.000107404967754912; 8.322462166015701e-5 -1.4475246432446863e-5 -2.7698318051449504e-5 0.00013069972849970632 4.414710495675721e-5 7.961838058017947e-5 0.00013396968937067288 -6.157589156591533e-5 0.00017729123577064413 7.0921646722672e-5 5.218914451015851e-5 7.253342412932024e-5 -3.7054255487782394e-5 -7.55944208935316e-5 1.2281200026883662e-5 8.83673121576061e-5 4.275575268602634e-5 8.808131609166261e-5 -9.590010238215631e-6 -2.0219284822268276e-5 -0.00013497749820762646 9.294688719146735e-5 -0.00011399106803528323 4.5834130317054614e-5 -4.547390996540118e-5 0.00010195363505099663 -9.606764139552798e-5 -7.17709706508677e-5 -9.19411038625204e-5 -1.4885190255334216e-5 -0.00021592268673648583 -1.845651824995021e-6; -0.0001740335263062938 0.0001203947196730665 4.88064162740392e-5 0.00011095076679646979 
0.00017167885580726566 4.625688070824944e-5 -7.229379643524088e-5 0.00022416853080664437 4.574954273573164e-5 -0.00011648757008208848 -3.765831006297132e-5 -0.000128918121660374 2.9955512403257842e-5 -3.385045583543478e-5 1.452929628009653e-5 -2.446380476306137e-5 -5.428441527022215e-5 -0.00019623078496499414 -7.47706560983845e-5 3.050128925985443e-5 -0.00014868250381971812 3.882456640291557e-5 -0.00019933534876748164 -5.103009407014916e-5 -4.1465189312141643e-5 -6.865041159165353e-5 5.710939334101491e-6 -0.00011453465940269808 0.0001010869038540246 -7.536927095919248e-5 7.265455409081053e-6 0.0001443955718090221; 0.00017716762596029588 6.958389058928364e-5 4.368284208691347e-5 -0.00014003058385345495 -2.364399114521113e-5 -0.0001513187373267952 -6.687898639132746e-5 5.99332434722592e-6 -0.00023091243128071785 -0.0002269081371115823 -0.00012166979212560146 -0.00025369485155456666 -0.00016142117164343568 -0.0001411733457561931 2.3431105372593243e-5 1.9002797497546224e-5 -0.00018136598299969145 0.00014517119489832546 -7.058405680356396e-5 -0.0001925211756235396 2.916872182033127e-5 -0.00016994045944665738 -4.287305199903633e-5 -4.491905128015097e-5 0.00014492832343312748 -5.769197668195308e-5 -5.3419989118258203e-5 -0.00017384245365172787 -0.00011773472141718716 -3.2599523539095015e-5 7.957269627588798e-5 0.00010367525902336043; -7.679505551482234e-6 -8.909508383060161e-5 1.8555610359327947e-5 9.269271627141452e-5 6.57246049291432e-5 8.250521873431085e-5 0.0001536650737312469 0.00021934454648596936 0.00012938910308600711 -0.00021551962040523693 6.017346036112528e-5 -2.7471646398620395e-5 2.716142021251862e-5 -7.440870162961812e-5 -8.07905459163741e-5 -2.658719008598179e-5 -1.2883964381603781e-5 5.925917807923448e-5 -0.00013621843449931678 -7.801275992226201e-5 2.182086004775016e-5 -4.618184234406946e-5 7.5409378501489626e-6 3.868460090031242e-6 0.0001337415081712402 -3.961246752571614e-5 -0.0001306899145214841 8.768439815919451e-5 0.000129940125908042 
0.00012238686699727325 -7.740622154363184e-5 0.0001086870391650877; -8.673906019710338e-5 0.00018752347635283767 3.968471240197499e-5 -9.686793663679094e-6 5.2454589177544246e-5 -9.272184003664449e-5 7.074549312999889e-6 0.000119202394426723 -6.592517610514237e-6 -0.00016212384550188147 -8.183549404664282e-5 5.984254924525834e-5 -2.135341789569593e-5 -8.209005069974994e-5 0.0001355456067639559 0.00010121540502124979 0.00020650365580051803 0.00015145640894017535 -0.00027937306483098235 -2.2083742141219256e-5 -0.00010976799719090003 5.691773797563957e-5 0.00014290138163316663 -3.65734861307619e-5 -8.551283578442663e-5 -9.287657782722569e-5 6.537669133579602e-6 0.00019176998715094887 -0.00010393407616818822 -5.165417430278192e-5 9.460087067892619e-6 2.4478101050514956e-5; 1.5815141317558618e-5 -0.00023796642037852734 -2.8837486958253013e-5 6.9510617991564566e-6 -7.566190308936865e-5 8.856001635266386e-6 8.749168207356209e-5 0.0001138363093319442 9.50944629364455e-6 1.2052022405963286e-5 5.118129208638214e-5 0.0001776042643301178 7.834474106273378e-5 -5.532779347489752e-6 -5.967376112168736e-5 3.1750558023972117e-6 4.279153451859767e-5 -9.530547524848985e-5 8.265258090755885e-5 2.1528299633770826e-5 5.9837251477415894e-5 -0.00014385384404378858 3.4954185842776814e-5 -4.1670149791045045e-5 5.757737542224203e-5 0.00010817872114201002 -3.292323729185384e-5 0.00010469481073739756 7.8664453916254e-5 -9.173692362894299e-5 0.000154562747205612 3.1812623795812584e-5; -0.00021829912430777298 3.4526793106381894e-5 -2.664323575089926e-5 -6.767139849188517e-5 3.237977442975699e-5 -0.00012217208269223076 2.3013697946506245e-5 -8.155802925986612e-5 -8.358757760080174e-5 -4.071940339573315e-5 0.00011267368835459287 4.160273828374659e-5 0.00010923621311147718 -7.180397506964623e-5 -5.220638294955782e-6 5.44906241911128e-6 -3.948310541974448e-5 -0.00011343450160866294 -9.938722095906527e-6 7.3545654677477e-5 2.8625756831649334e-5 4.767107974683338e-5 8.831011517967011e-5 
-0.00010601473469225146 -0.0002948321721393946 -0.00010382142637328495 -8.411904992473694e-5 2.7846761886653722e-5 -3.69087988562616e-5 -5.155745885505728e-5 -6.813338595483327e-6 2.8411405301370447e-5; 9.95755324169001e-5 -7.700006122495928e-6 0.00012141097953130835 0.00010109556727397602 0.0002628982253867166 0.00019780974956517725 -0.0001546990900038981 -0.00015218521754434706 0.0001244538504891932 0.00018223818163675212 0.00013589268904467506 -8.454875957399587e-5 -0.00012137609179775002 3.4384207296427584e-5 0.00016546920936377234 -0.00018445241609300544 -9.933942807550238e-5 -1.319939589452103e-5 5.685738845781193e-5 -5.59429647779613e-5 -0.00011769407161921996 -6.715955557013171e-5 5.3418061483505566e-5 0.0001733080059551598 -0.0001279410573857861 -6.016511930795861e-5 -2.1763908321763295e-5 -6.698361563906698e-5 0.0001547315548435649 7.743525905822976e-5 2.6080944493844654e-5 0.0001154287817683723; -1.888107783575562e-5 -0.0001446063494089296 0.0001339581698399758 0.00017908242205034696 2.501216965778678e-5 -6.091441412114132e-5 3.40299533260309e-5 -5.19813826261792e-5 1.2572676837083685e-5 -7.367121236289357e-5 -0.00013787716545753676 -4.973081979674612e-5 6.958527475427311e-7 -9.427784482372395e-5 -5.522378944566045e-5 4.828546786906693e-5 -2.0196272658137778e-5 -2.3949244337341406e-5 2.0601273149301464e-5 0.00011449723876889302 -2.3016631306353194e-6 -0.00012805255630642708 -0.0001965607467585478 0.00013398539647336533 6.92960287548362e-5 5.6220419878302686e-5 -8.657572709613069e-5 0.0002076657159474444 6.7012265730813e-5 -0.00018631347722963015 7.831478922038846e-5 4.9612657107578445e-5; -7.215489778428721e-5 0.00012338577188091926 -2.3953856405713356e-6 4.752370558452108e-5 -4.5944635635180076e-5 0.00015344458632643 7.600829384062697e-5 0.00015268641699111676 0.00012554217654286445 9.584974294301612e-5 -0.00013786681403399814 -0.00017536621191067273 -2.2997471310020664e-5 0.00013416647096635043 5.755553852566406e-6 -0.00019873572920442998 
7.381809235556597e-5 -0.00022811483824153995 -0.00015237930812401544 -0.0001231883844546359 9.173729893990391e-5 0.0001311244221917042 6.0601704613134174e-5 8.591763879411514e-5 8.29752633628124e-5 0.00016363600560470124 -4.1866847018183824e-5 3.7396871344014884e-5 5.3895367168648714e-5 -0.0001375219481950008 -6.617027712746981e-5 -0.0001649036758997831; 0.00013673057072054057 -3.261570565911512e-5 4.479185043208207e-5 2.073388604619776e-6 -9.411727989601447e-5 4.3952728792359146e-5 8.719424276639857e-6 5.134928540399752e-5 9.002240074522158e-5 -2.0559896820776012e-5 4.49212018764522e-6 -0.00016216210098748027 -1.1455926706109724e-5 -0.0001950628171656166 0.00014526053962952883 5.918976274159095e-5 -8.182791827522763e-5 0.00024369162963892429 -5.582723280924618e-5 1.7090320743303154e-5 2.4844292401833315e-5 0.00010421956938415322 9.777160847050527e-5 -6.188061126551135e-5 -5.593726711623816e-5 -0.00012998247586466875 3.3919019533353754e-5 0.00014793746629919314 9.190116941652456e-5 -8.738160579042985e-5 1.542620555454204e-5 -0.00017161573816367611; 0.00015694885334506493 0.00011718346033286106 -0.00015725994197144475 5.177857494256155e-5 -8.019406848779188e-5 -9.816738184774558e-6 -2.4402158288451392e-5 -9.508771361742404e-5 4.206089690252221e-5 -0.00017471586650662105 -3.2709815769513865e-5 9.90703503613503e-6 -5.377746381872431e-5 5.558708400626026e-5 4.4657160460778705e-6 -0.00013568949250929033 -2.329280212597119e-5 7.570569252403782e-5 -8.644785592057805e-5 0.0002271239511346801 5.215107850654897e-5 -1.6637400022766658e-5 -4.795619844108778e-6 4.0679123429988894e-5 -6.589102011366103e-5 -8.005072484683438e-5 -2.9670742861519164e-5 0.00013352857897098978 -0.0001543680980657768 0.00011166957230580731 -1.465564653789337e-5 -0.00015323575533419075; -2.9461990498173855e-5 6.85907606292237e-5 -0.00012171110313844753 -0.00011836642545874117 -0.00010853138258441353 0.0001420137513103615 0.00020106360431767177 -6.238095841596836e-5 6.357495555660719e-5 
-3.1172300411341065e-7 9.301668876090095e-5 9.964195020969106e-5 -7.138980175129122e-5 5.088416737486603e-6 6.454065703094942e-5 9.166063943651537e-5 -5.082553355365978e-5 6.190419180559276e-6 -6.065108130330074e-5 -3.3009878398086786e-5 -3.385559205541368e-5 -5.0583677084585725e-5 0.00015465155126533482 -1.6870105473725144e-5 0.00010381942672195836 -0.00020086471807205115 -0.00015882890614208768 -5.2144024384841186e-5 0.00014969059963717177 4.367712869664652e-5 -9.40668370550774e-5 -6.482808851630294e-5], bias = [3.148647374105815e-11; 7.782062782997473e-10; 1.1147615500485922e-10; -1.7282165854214769e-9; -1.7451345738047184e-9; 1.3942196925497262e-9; -4.5059137250715835e-9; 4.62745028058986e-10; 3.81620406495497e-9; -5.876530370169433e-9; -4.285004776743963e-9; -1.4120944632652934e-9; -2.3329676913608998e-9; 4.545872923471428e-9; 7.863317584909448e-10; 2.9426054476653893e-9; 1.87155949060896e-9; 2.1734275782637996e-9; 1.6033731729101584e-9; 1.1197995286382597e-9; -1.8024335376246115e-9; -7.3843613750914415e-9; 1.9275381215125304e-9; 6.573199455661923e-10; 2.1116407690391642e-9; -3.501108694511685e-9; 2.97562507672343e-9; -3.880135023396015e-10; 4.901132234952043e-10; 7.752348065763573e-10; -1.235804113668911e-9; -1.6473692849953148e-10;;]), layer_4 = (weight = [-0.0006625376857563734 -0.0005814426794621149 -0.0008052763232868552 -0.0007138872246424775 -0.0008827990399737088 -0.0006064185028924932 -0.0006470058640809428 -0.0007323891003113015 -0.0006256381261162649 -0.0005428546834979917 -0.0008503596128407315 -0.0005967740165607107 -0.0005620010923579409 -0.0006983899520078766 -0.0008345212419038957 -0.0007098534788415799 -0.0006869014432223204 -0.0006031133898870843 -0.0007152961528497986 -0.0007234157043771161 -0.0008821166525540704 -0.0005233727824305787 -0.0005786382980549659 -0.0007231973819613111 -0.0006251769944036685 -0.0008337118241811358 -0.0005923757676575531 -0.0007793351010349525 -0.0007972836786549282 -0.0008062225257684348 -0.0006858039139906729 
-0.0005904006181463546; 0.0001781405303117125 0.00023640993616872507 0.00040859169199097253 6.636655283375594e-5 0.00016386353877531542 0.00014247964766641788 0.0003867291084823789 0.0004489774752600957 0.00021840958318790696 0.00034924118443328757 0.00035654608348826026 0.00011119809373355456 0.00030058304580772067 0.00024517769848903404 0.0003259752912818825 0.0002361434088148043 0.0003824969095395633 0.00016958066736768066 0.000226920500809224 5.214944612904287e-5 0.00012345451309021875 0.00011360261232219211 0.0002474988591258941 0.00016642859641297745 0.00015009743355522479 0.00017679672660977694 0.00029731941198558843 0.00022027339510446057 5.806169291088688e-5 0.00035771100312097416 0.0002098680053185889 0.0003400498355848862], bias = [-0.0006927223160828577; 0.00022820321143064714;;]))

Visualizing the Results

Let us now plot the loss at each training iteration.

julia
begin
    # Loss history collected in `losses`: a thick line with circular markers on top.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results: the data, the untrained prediction, and the trained prediction.

julia
# Re-integrate the neural ODE with the optimized parameters `res.u` using fixed-step RK4
# (step `dt`, adaptive stepping disabled), saving the state at `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4();
    u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false) |> Array
# Keep only the first output of `compute_waveform` (presumably the waveform itself).
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    # Overlay three waveforms on one axis, each drawn as a line plus scatter markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Target waveform data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Neural-network prediction before training (`waveform_nn` is defined earlier in the file).
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Neural-network prediction with the optimized parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    # Each legend entry groups a line with its markers; the legend sits at the left bottom.
    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
# Report the Julia/system environment, then GPU toolkit details for any usable backend.
using InteractiveUtils
InteractiveUtils.versioninfo()

# CUDA details only when LuxCUDA was loaded and CUDA reports itself functional.
if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

# AMDGPU details only when LuxAMDGPU was loaded and reports a functional device.
# `LuxAMDGPU.functional()` can itself log a warning when AMDGPU is unusable.
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 550.54.15, originally for CUDA 12.4

CUDA libraries: 
- CUBLAS: 12.5.2
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.2
- CUSPARSE: 12.4.1
- CUPTI: 23.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.4.2
- CUDA_Driver_jll: 0.9.0+0
- CUDA_Runtime_jll: 0.14.0+1

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.623 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/DaGeB/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.