Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
      Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped: it defines helper functions to simulate the model, which are not particularly relevant from a scientific machine learning perspective.

We need a very crude two-body path. Assume the one-body motion is the Newtonian relative position vector $r = r_1 - r_2$ of the two bodies, and use the Newtonian formulas to recover $r_1$ and $r_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, Section 4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) path into the paths of two bodies with masses `m₁` and `m₂`.

With `M = m₁ + m₂`, returns the tuple `(r₁, r₂)` where `r₁ = (m₂ / M) .* path` and
`r₂ = (-m₁ / M) .* path`. Consequently `r₁ - r₂ == path` (up to rounding) and
`m₁ r₁ + m₂ r₂ == 0`, i.e. the centre of mass stays at the origin.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    body1 = (m₂ / total_mass) .* path
    body2 = (-m₁ / total_mass) .* path
    return (body1, body2)
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution into Cartesian orbit coordinates.
#
# `soln` holds one variable per row and one time sample per column. Two layouts are
# accepted:
#   * 2 rows `[χ; ϕ]`: `model_params` must be a length-3 collection `(p, M, e)`;
#     the constants `p` and `e` are used for every sample (`M` is not needed here).
#   * 4 rows `[χ; ϕ; p; e]`: `p` and `e` are read per sample from rows 3 and 4.
#
# The radius is `r = p / (1 + e cos χ)`, and the result is the 2 × size(soln, 2)
# matrix `[x'; y']` with `x = r cos ϕ` and `y = r sin ϕ`.
# Throws `AssertionError` for any other row count, or when the 2-row layout is used
# without a length-3 `model_params`.
@views function soln2orbit(soln, model_params=nothing)
    nvars = size(soln, 1)
    @assert nvars ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    # Rows 1 and 2 hold χ and ϕ in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if nvars == 2
        # Check for `nothing` explicitly so the caller gets the assertion message
        # instead of a `MethodError` from `length(nothing)`.
        @assert(model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2")
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

This function approximates the first derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v, dt)

First time derivative of the uniformly sampled signal `v` with sample spacing `dt`.

Interior samples use the central difference `(v[i+1] - v[i-1]) / (2dt)`. The two
endpoints use second-order one-sided stencils. The result has the same length as `v`,
which must contain at least three samples.
"""
function d_dt(v::AbstractVector, dt)
    # Second-order one-sided stencils at the two ends.
    first_pt = -1.5 * v[1] + 2 * v[2] - 0.5 * v[3]
    last_pt = 1.5 * v[end] - 2 * v[end - 1] + 0.5 * v[end - 2]
    # Central differences for every interior sample.
    interior = @views (v[3:end] .- v[1:(end - 2)]) ./ 2
    return vcat(first_pt, interior, last_pt) ./ dt
end
d_dt (generic function with 1 method)

This function approximates the second derivative using the standard three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v, dt)

Second time derivative of the uniformly sampled signal `v` with sample spacing `dt`.

Interior samples use `(v[i-1] - 2v[i] + v[i+1]) / dt²`. The endpoints use the
second-order one-sided stencil `(2, -5, 4, -1) / dt²`. The result has the same length as
`v`, which must contain at least four samples.
"""
function d2_dt2(v::AbstractVector, dt)
    dt_sq = dt^2
    # One-sided four-point stencils at the two ends.
    left = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    right = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    # Three-point central stencil for the interior.
    inner = v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    return vcat(left, inner, right) ./ dt_sq
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Mass-weighted trace-free second-moment component of a planar orbit.

`orbit` holds the `x` coordinates in row 1 and the `y` coordinates in row 2. For
`component == (1, 1)` this returns `mass .* (x² - (x² + y²)/3)`. For `(2, 2)` it returns
`mass .* (y² - (x² + y²)/3)`. Every other component returns the off-diagonal
`mass .* x .* y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs, ys = orbit[1, :], orbit[2, :]
    Ixx, Iyy = xs .^ 2, ys .^ 2
    tr_I = Ixx .+ Iyy

    tensor_component = if component[1] == 1 && component[2] == 1
        Ixx .- tr_I ./ 3
    elseif component[1] == 2 && component[2] == 2
        Iyy .- tr_I ./ 3
    else
        # Off-diagonal entries need no trace correction.
        xs .* ys
    end
    return mass .* tensor_component
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the mass-weighted
trace-free moment component `component` of `orbit`, as produced by `orbit2tensor`.
"""
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)

"""
    h_22_quadrupole(dt, orbit, mass=1)

Quadrupole terms of a single orbit. Returns the tuple `(h11, h12, h22)` (note the
order), each entry being `h_22_quadrupole_components` for that tensor component.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_of(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return component_of((1, 1)), component_of((1, 2)), component_of((2, 2))
end

"""
    h_22_strain_one_body(dt, orbit)

(2,2)-mode strain of a single (test-particle) orbit sampled every `dt`.

Takes `(h11, h12, h22)` from `h_22_quadrupole` with unit mass, then forms
`h₊ = h11 - h22` and `hₓ = 2 h12`. Returns `(c * h₊, -c * hₓ)` with the normalisation
`c = √(π/5)`, where the constants are built in the number type of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation √(π/5). Dropping the square root scales every waveform
    # by π/5 instead, i.e. off by an extra factor of √(π/5) ≈ 0.79.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum of the quadrupole terms of two bodies. Returns `(h11, h12, h22)`, where each entry is
the elementwise sum of that body's `h_22_quadrupole` output.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

(2,2)-mode strain from the orbits of two bodies (BH 1 with `mass1`, BH 2 with `mass2`),
each sampled every `dt`.

The quadrupole terms of both bodies are summed, then combined into `h₊ = h11 - h22` and
`hₓ = 2 h12`. Returns `(c * h₊, -c * hₓ)` with the normalisation `c = √(π/5)`.
Throws `AssertionError` unless `mass1 + mass2` equals 1 to within `1e-12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation √(π/5). Dropping the square root scales every waveform
    # by π/5 instead, i.e. off by an extra factor of √(π/5) ≈ 0.79.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

(2,2)-mode strain of the orbit described by the ODE solution `soln`, sampled every `dt`.

`soln` and `model_params` are converted to a relative orbit by `soln2orbit`.
- `mass_ratio == 0`: the test-particle strain `h_22_strain_one_body` is used.
- `0 < mass_ratio ≤ 1`: the orbit is split by `one2two` into two bodies with
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, so that `m₁ + m₂ = 1`.

Returns the two-element tuple produced by `h_22_strain_one_body` or
`h_22_strain_two_body`. Throws `AssertionError` if `mass_ratio` is outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Masses normalised to unit total mass.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi$ and $u[2] = \phi$,

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = [χ, ϕ]`.

With `c = cos χ`, `N = (p - 2 - 2e c)(1 + e c)²` and `D = √((p - 2)² - 4e²)`:

    χ̇ = N √(p - 6 - 2e c) / (M p² D)
    ϕ̇ = N / (M p^{3/2} D)

`t` is unused. Returns the new vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    cχ = cos(χ)

    common = (p - 2 - 2 * e * cχ) * (1 + e * cχ)^2
    root_term = sqrt((p - 2)^2 - 4 * e^2)

    dχ = common * sqrt(p - 6 - 2 * e * cχ) / (M * (p^2) * root_term)
    dϕ = common / (M * (p^(3 / 2)) * root_term)

    return [dχ, dϕ]
end

mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions
datasize = 250
# Time span for the GW waveform. Kept in Float64 to match `u0` and the Float64
# parameters; a Float32 span would make `tsteps`, `dt_data` (and hence the number
# type used inside `compute_waveform`) single precision.
tspan = (0.0, 6.0e4)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # sample spacing used for the finite differences
dt = 100.0                       # fixed RK4 step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl.

julia
# Reference ("true") trajectory: integrate the relativistic model with fixed-step RK4
# (step `dt`, adaptive stepping off), saving only at `tsteps`. `Array(...)` turns the
# solution into a matrix with one row per state variable (χ, ϕ) and one column per
# saved time.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a two-element tuple; `first` keeps the first component,
# which is the target `waveform` used later by the loss.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers; `l` and `s` share one legend
# entry. Ending the block with `fig` returns (and displays) the figure.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the angle $\chi$, i.e. `u[1]`) and has two outputs. We then write a function `ODE_model` that takes the current state `u`, the neural network parameters, and the time `t` as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We deviate from the default neural network initialization and instead initialize the weights with `truncated_normal` (standard deviation `1e-4`) from WeightInitializers.jl:

julia
# Network mapping the angle χ (passed as a 1-element vector) to two outputs.
# The first layer applies `cos` elementwise, so the network output is 2π-periodic in χ.
# Weights start tiny (truncated normal, std 1e-4), so the untrained network is nearly
# constant in χ.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# Seed the RNG so the initial parameters are reproducible across runs.
ps, st = Lux.setup(Xoshiro(0), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[6.6231034f-5; 8.040741f-5; 2.4146846f-5; 9.728673f-5; -7.738072f-5; -1.3412722f-5; 5.5154098f-5; 0.00011454331; -0.00011787867; -2.3195911f-5; 0.00019253469; 1.189833f-5; -0.000102975646; 2.5933896f-5; 9.880327f-5; 5.30923f-5; 3.1889478f-5; 0.00012094199; 1.6285752f-5; -0.00017049465; -6.519985f-5; 0.00010549428; 0.00012267953; -2.341447f-5; -2.5151798f-5; -5.718582f-5; -8.736814f-5; 1.6228023f-5; -4.9056092f-5; -0.00011611969; -0.00018902734; -0.00017749246;;], bias = Float32[0.3179388, -0.5745903, -0.099388, 0.16789341, -0.5776521, 0.3880942, -0.11283982, -0.13419473, -0.04003322, 0.2762102, -0.5337069, 0.25753474, -0.38138282, -0.2581538, 0.11033511, 0.9027512, -0.9034306, 0.25934494, 0.5407522, -0.79059625, 0.34238505, -0.89984524, -0.671723, 0.13847423, -0.5991702, -0.45280135, -0.5255486, -0.73688066, -0.53758705, -0.7259551, 0.9868355, -0.85511625]), layer_3 = (weight = Float32[-3.0663206f-5 -0.00016208204 0.00012403492 5.1166542f-5 4.518453f-5 8.569371f-5 -7.148226f-5 -0.0001133331 4.0099003f-5 -2.8557088f-5 -0.00017484758 -5.276063f-5 -8.4613916f-5 -5.6824203f-5 -1.5841306f-5 3.5167224f-5 -8.977989f-5 -8.805507f-5 5.8567734f-5 -0.00011041424 -9.166553f-5 -0.00013461045 -0.00013812866 -0.00012434364 2.252775f-5 9.493882f-5 5.235604f-5 -3.9504346f-5 4.8868515f-5 6.405822f-5 -6.7983245f-5 0.00013040728; 0.00016163234 6.3927423f-6 -3.7555168f-5 5.0979335f-5 -5.6953566f-5 -8.026777f-5 1.5825965f-5 -0.0003323877 -0.00012743418 2.8482646f-5 9.502375f-5 -8.803089f-5 -1.6042032f-6 3.2302707f-6 -0.00020780506 -1.5358116f-6 -2.7711615f-6 0.00025079606 -4.3316475f-5 -4.4407298f-5 1.2994173f-5 5.862348f-5 -2.9825002f-5 -0.00010641367 6.354025f-5 6.736677f-5 -2.6056592f-5 -5.2734304f-5 4.7542482f-5 -3.3631444f-8 3.421057f-5 -3.8598988f-5; 5.44328f-5 6.471798f-5 4.276385f-5 -2.850256f-5 -3.7751327f-6 -0.00014484886 0.00023496657 -1.542162f-5 2.0178257f-5 4.573752f-6 0.00010298057 -8.516785f-5 4.7842317f-5 
-0.00016773434 4.8089994f-5 0.0001415731 0.00016199729 9.928568f-5 8.214988f-5 0.00012435422 -0.00010726531 0.00025105182 -0.00014175063 -8.954718f-5 -0.000120581346 0.00010539259 -1.8787523f-5 -0.000101238635 -4.080401f-5 5.8521677f-5 8.495256f-5 0.00013503202; -0.00013757273 -0.00018927536 -3.250951f-5 -9.276694f-5 0.00013743943 -0.00011121622 -0.0001250889 -0.00011172029 -6.3228035f-5 0.00014294795 6.2890496f-5 0.00010335798 -0.00016103551 6.3185005f-5 -4.882303f-5 -2.2549253f-5 -6.4775115f-5 -5.0420454f-5 -9.5179435f-5 -1.6837404f-5 0.00015117976 3.348035f-5 3.5240253f-5 -0.0001612498 -5.195452f-5 0.0002002866 -1.7100081f-5 2.342499f-5 5.9728754f-5 -7.312719f-5 4.086078f-5 -9.338634f-5; -7.035169f-5 -0.00016396336 5.324751f-5 3.2371438f-5 9.2591086f-5 -0.00011410086 8.62952f-5 -7.955446f-5 0.00012342524 1.6911349f-7 -0.000121008125 -6.28906f-5 0.000121481644 7.0602546f-5 -0.00012087387 -4.3850694f-5 -2.1862605f-5 -6.524176f-5 -2.3515788f-6 0.00018759219 -0.00016656004 -0.000152812 2.1746475f-5 1.4451544f-5 -8.709109f-6 -9.999444f-5 2.9582352f-5 -6.5673696f-5 -7.889765f-5 0.000123362 -1.6373084f-5 5.240089f-5; 3.215737f-5 0.00017174211 -6.436769f-5 9.610916f-5 7.7247554f-5 6.5564345f-5 -0.00022910685 -3.5679896f-5 6.555541f-5 9.259076f-5 -7.4275915f-5 -0.00013147062 0.00014524689 -2.0252568f-5 -4.7311474f-5 -0.0001770485 -5.9182305f-5 -0.00011697365 -5.5630582f-5 -7.347584f-5 -4.8541977f-5 -2.9916984f-5 2.5520421f-5 0.00014857027 5.1351076f-6 2.329112f-5 4.2443848f-5 -3.1856594f-5 7.654836f-5 3.2261134f-5 2.7745487f-5 -8.3628416f-5; -0.00010835042 -3.6862457f-5 0.00019454717 3.2303775f-5 -4.3214837f-5 -0.00010914766 -4.1139858f-5 -2.2879144f-6 -8.886689f-5 -1.7083088f-5 0.00016147768 6.0316157f-5 -7.065938f-5 0.00012093111 -0.00021048042 -0.00022622512 3.6913614f-5 4.8041613f-5 -0.000116251555 0.00017233324 8.4290295f-5 -9.464381f-6 5.5637094f-5 -8.992671f-5 2.3078921f-5 -0.00010308122 -0.0002191234 -0.00017911215 4.9335117f-6 0.0001059991 -0.000116434596 
0.000117869306; -0.00017514505 9.779926f-5 0.00011858776 -4.7271245f-5 4.9897164f-5 -0.00015960626 -0.00015447836 6.13955f-5 -0.00016223296 5.6464773f-6 -9.9826975f-6 0.00013951071 -0.000117319236 -9.023217f-5 -1.0128492f-5 -0.00018665727 5.7356036f-5 8.521598f-5 -6.846402f-5 5.3232463f-5 -7.14397f-5 -6.4714906f-5 -9.71375f-5 -0.00010063836 -8.514572f-5 4.2472162f-5 -0.00012559854 -6.437742f-5 -2.961927f-5 -9.2150985f-5 7.8977835f-5 9.754822f-6; 9.002284f-5 -9.5824646f-5 9.974843f-5 0.00018277709 -6.1394116f-5 -1.9526358f-5 5.043118f-5 -4.764175f-5 -5.191583f-5 -5.0789124f-5 -0.000117663985 -1.9604824f-5 -4.3107262f-5 7.499624f-5 6.2680556f-5 0.00012511903 -0.00011079951 -1.2293222f-5 0.00020400797 -7.681027f-5 -6.365785f-5 -0.0002171678 0.00011371447 8.638166f-5 3.8846407f-5 7.7950666f-5 -1.451952f-5 9.1437305f-6 4.4984186f-5 4.571274f-5 -7.943672f-5 -0.0001529736; -0.00024876345 -6.9473164f-5 0.00010514772 -0.0001058871 -2.4650059f-5 -0.00011358361 3.6604433f-5 -9.4918796f-5 3.926404f-5 1.8800343f-5 -1.2617714f-5 -0.0002101285 9.516783f-5 -9.7606484f-5 -1.5120196f-5 -0.00017524118 0.00012285027 9.5343494f-5 -0.00010953761 -9.781375f-5 -0.00010382624 8.059813f-5 4.384923f-5 -3.4456847f-5 1.9892257f-5 0.00014372486 4.0145118f-5 4.6534402f-5 5.451406f-5 1.2859452f-5 5.1909792f-5 0.00020366587; -0.00011216841 -9.352545f-5 1.0820325f-5 0.00011497941 6.333415f-5 -2.9585382f-5 -5.484149f-7 5.8387537f-5 3.676283f-5 -3.178475f-6 6.4561165f-5 0.00012109539 -0.00020504805 3.482308f-5 -1.285915f-5 0.00010376846 9.380684f-5 -0.00012004763 -5.6934015f-5 1.0675703f-5 0.0002460085 0.000117933516 4.9786315f-5 3.638063f-6 -1.3785428f-5 4.576846f-5 -0.00012773846 3.9342955f-5 5.677768f-6 -0.00011390564 3.036621f-5 5.23301f-6; 2.1621508f-5 -6.6034976f-5 -0.00010154794 -3.7421458f-5 -9.906072f-5 -6.8961635f-5 0.00016148566 5.5673238f-5 -0.00019579541 4.0188184f-5 4.682844f-5 0.00014438668 -1.1681819f-5 3.1048352f-5 -1.579489f-5 -0.00012762795 4.5579545f-5 -8.2682374f-5 1.7668644f-5 
-3.0990755f-5 1.4805749f-5 -8.6046064f-5 -6.0001366f-5 6.429432f-5 -8.51449f-7 -0.0001178722 1.9084171f-5 0.00011041435 4.727076f-5 2.776336f-5 -9.4547904f-5 -7.4250083f-6; -2.1574311f-5 -2.6311769f-5 -0.00014562426 -0.00010406075 4.44243f-5 7.254354f-5 3.0596097f-5 2.960029f-6 9.371278f-5 -6.0591552f-5 -0.00019529452 6.148526f-5 0.00013258564 -0.00016935916 6.5750806f-5 1.587518f-5 7.13107f-5 -2.3676368f-5 0.00012572233 7.7386685f-5 -6.217161f-5 -5.730356f-6 -0.00017914375 -9.6195436f-5 6.381152f-5 -2.323391f-5 -5.7629215f-5 0.00012693802 -0.0001964098 -6.176497f-5 0.00024618852 -0.00019073607; -6.973166f-5 0.00016393002 -0.00010235659 0.00014490927 -6.6111155f-5 -0.00013216003 3.0489304f-5 -0.00016050802 7.238626f-5 7.409247f-6 6.602712f-5 -1.0795995f-5 -0.00016162633 -1.0580473f-5 -3.25692f-5 -5.2364387f-5 -0.00017188724 -3.2197586f-5 6.3966356f-5 0.0001643777 -1.0848956f-5 -0.00020082433 3.631001f-5 3.1758027f-5 -2.888522f-5 2.9797477f-5 0.00017421745 0.00012714171 -8.191415f-5 6.970787f-5 7.1463735f-5 0.00013703413; -7.57f-5 0.0001050237 -1.3470017f-6 9.8526354f-5 -7.248131f-5 -7.234537f-5 -2.1228234f-5 4.0759474f-5 6.582407f-5 -7.091985f-5 3.0439645f-5 7.747634f-5 3.9947012f-5 7.9223755f-5 -5.4228865f-5 8.95945f-5 0.00013370489 0.00011439464 -0.00018512856 2.8345792f-6 -8.401538f-6 -7.74557f-6 0.00010549588 6.8559325f-6 7.4068885f-5 -5.5730095f-5 1.9426823f-5 -0.00017407213 4.8742957f-5 0.00011683885 -4.1225057f-6 2.7556838f-5; 6.318239f-5 -5.327867f-7 6.385541f-5 -0.00016248618 7.975426f-5 -0.00019114217 -5.9078953f-5 5.6689846f-6 -1.8747243f-5 -6.302776f-5 -0.00012834353 6.798005f-8 0.00012352174 0.00013674838 -3.8609076f-5 0.00010451606 -0.00010724887 -0.000100645746 -6.829556f-5 5.5489138f-5 0.00019383074 -0.0002896118 8.8063636f-5 -1.8685254f-5 -0.00017268461 4.659745f-5 -0.00010447737 1.4620114f-5 0.0001423945 0.00012934979 -0.00021476521 7.98377f-5; -7.3361094f-5 -3.550534f-5 -8.465106f-5 -7.9540194f-5 5.4757835f-5 1.068054f-6 8.671901f-5 -1.7326867f-5 
4.4279327f-6 -2.0321138f-5 -7.647649f-5 -7.291174f-5 1.5260528f-5 5.0832863f-5 0.00014007006 -4.2062406f-5 -0.00010452727 -8.181587f-5 -6.0980943f-5 -5.7514186f-5 -4.6058027f-5 -7.294564f-5 1.2380104f-5 0.00017351576 -6.152009f-5 0.00020497704 -8.0569f-5 9.55104f-6 -0.00023729743 -0.000106603795 3.569067f-6 1.3889042f-5; -6.074308f-5 1.4326351f-5 -2.668652f-5 1.960117f-6 -0.00010868177 1.617436f-5 0.00012995963 4.3177984f-5 -3.8903592f-5 -0.00017106754 2.2739052f-5 -0.00012905938 -0.00013590523 0.0002607076 -0.0001242622 -4.99148f-5 5.834399f-5 -6.3862444f-6 -8.2706916f-5 -8.4253925f-6 3.006717f-5 -8.3432125f-5 0.0002611432 -5.679418f-5 -5.1571908f-5 -7.82655f-5 0.000119659664 -6.639714f-5 0.000104058665 -5.73203f-5 0.00011982925 3.8717775f-5; -3.130982f-5 -7.180091f-5 6.296856f-5 3.188651f-5 -0.00020825498 1.5046344f-5 -0.0001846267 -6.167479f-5 0.000106254156 8.564353f-5 6.847113f-5 -9.625674f-5 -3.6753358f-5 -4.10619f-5 1.5229859f-5 6.7603716f-5 -2.7237786f-5 -1.6027952f-5 7.372334f-5 -3.094069f-5 -4.394532f-5 -6.1851497f-6 -6.285695f-5 -0.00011073665 -7.768898f-5 1.9851403f-5 -7.89314f-6 8.793458f-5 0.00017658004 -2.7939132f-5 -0.00011810128 -1.8032266f-5; 0.00019869678 -1.1987848f-5 0.00017685603 8.110484f-5 0.00010564396 -5.5484386f-5 -5.3676147f-5 8.3773135f-5 8.7041924f-5 -2.1924474f-5 -7.2476178f-6 -0.00012712278 4.483084f-5 3.7625825f-5 -1.9618652f-5 -0.00023658806 5.1477982f-5 0.000176413 5.0713705f-5 -0.00021392651 -0.00019777489 -0.00014323773 -0.00010341145 0.00010857606 -6.28547f-5 0.00025170553 -8.569101f-5 0.00019514476 -1.4831382f-5 4.4961627f-5 -6.146721f-5 -8.1853446f-5; 1.4746627f-5 -1.8582477f-6 -4.8146223f-5 8.5420994f-5 -7.3142066f-5 -9.958538f-5 5.8420424f-5 9.904891f-6 -2.608333f-5 0.00017538138 -3.232115f-5 -8.8590205f-5 5.3522293f-5 7.389571f-5 0.00015304316 3.6246645f-5 -0.00012947418 -8.602696f-5 2.3602794f-5 8.91458f-5 1.38626265f-5 3.1969346f-5 -5.7063793f-5 -3.4417026f-5 -1.0618861f-5 -0.00010106251 2.6949625f-5 2.8241904f-5 
0.00014759772 -0.00014103223 0.0001184898 -0.0001395846; -0.00010115728 2.0939999f-5 -6.505542f-5 6.636322f-5 2.815497f-5 -0.00013346225 -5.688925f-5 0.00015346093 5.177969f-6 -5.2735184f-5 6.9556445f-5 -1.7579265f-5 -6.4488566f-5 4.535487f-5 1.3689277f-5 -7.890483f-5 -3.7868234f-5 6.5487184f-5 -2.6990112f-5 4.0644136f-5 4.7402762f-5 -0.000105079685 -1.0512115f-5 3.0780157f-5 6.861355f-5 0.000112223854 -3.6515383f-5 -0.00013350062 1.8207536f-5 -8.354349f-5 1.8156896f-6 -1.858352f-5; 0.00013606566 5.9400165f-5 3.2240045f-5 0.00013480308 -3.2215124f-7 -4.771231f-5 1.4230568f-6 6.1771374f-5 3.1222764f-5 -2.1518015f-6 0.00020311416 2.2963397f-5 5.317318f-6 3.9302387f-5 8.032976f-5 -0.00022411543 -1.8543697f-5 -2.1121341f-5 -1.34163265f-5 1.1960672f-5 -7.757318f-6 -1.4184697f-6 2.4919213f-5 -8.083392f-5 1.9215295f-5 -2.0364545f-5 3.5972778f-6 -4.5137673f-5 -7.646713f-5 0.000105118095 2.8974304f-5 -0.00011276115; -8.255063f-5 4.146117f-5 -0.00019277561 4.5782064f-5 -6.8325615f-5 6.131696f-5 -5.3471656f-5 -6.660153f-5 0.0001479274 -0.00015370877 3.3602522f-5 2.965963f-6 8.282202f-5 -0.000103710234 -0.00020926961 2.0165313f-5 1.5897242f-6 -4.989588f-5 9.976803f-5 4.8675127f-5 7.434739f-5 -1.728463f-5 0.00011013572 -3.2935284f-5 0.00013103097 6.293201f-6 7.629341f-5 7.893281f-5 -0.00010510553 -0.00023015 -6.3807296f-5 0.00013406212; 0.00012888866 5.8962654f-5 -0.0003278351 -1.1387809f-5 0.00015898977 -9.423247f-5 -1.5929818f-5 0.000104059036 2.4859326f-5 3.6697842f-5 -0.0001050385 -7.783747f-5 0.00011095857 0.0001531535 8.095514f-6 0.000100635356 -0.00017611028 -0.00022151659 -4.2956355f-5 4.4743345f-5 -3.6020498f-5 -9.0944384f-5 -2.4893583f-5 1.3325803f-5 7.4777694f-5 0.00021454651 0.00015494044 -0.00018257144 5.2719526f-5 1.5875576f-6 3.8527764f-6 0.00011455146; -1.0510612f-5 -7.9507896f-5 -1.716159f-5 5.0922972f-5 -7.6525626f-5 -0.00019440807 0.00010505835 -2.2247383f-5 -3.3333745f-5 -7.03226f-5 -0.00013541193 -9.046142f-5 4.9246097f-5 -6.85296f-5 -5.9404705f-5 
0.00013898716 9.614508f-5 -3.584154f-5 -7.6867895f-5 1.3264858f-5 -8.698802f-5 -1.1008336f-5 0.00014101244 -0.000119909724 1.6569201f-8 -8.309715f-5 6.6399734f-5 -0.00022338086 -0.000106179956 -9.557949f-5 0.000107130814 -6.687888f-5; 5.3214197f-5 0.00020937118 4.1943513f-5 7.755596f-5 5.1607632f-5 0.00010474593 6.9160946f-5 1.661635f-5 -7.300738f-6 -0.0001564612 -0.00012868326 -0.00016819646 6.2911495f-5 7.3680836f-5 9.485799f-5 -9.577613f-5 0.00010897742 -2.8192138f-5 -3.5020814f-5 1.4728594f-6 -7.347124f-5 -0.00010769834 9.9259596f-5 0.00016536932 -9.312229f-6 -5.872265f-5 6.1536844f-6 5.504552f-5 0.00016509808 0.00016365373 -7.198006f-5 0.00012085911; 8.185449f-5 -0.00025090444 0.00016438737 3.68499f-5 -0.00013197708 7.4154974f-5 -7.256876f-6 -9.543281f-5 -9.1146813f-7 -7.585663f-5 -8.4494957f-7 -1.5517666f-5 8.3332416f-5 -1.00128045f-5 -7.135061f-5 5.1770567f-5 -5.2280306f-5 -3.696867f-5 3.826118f-6 -3.3374035f-5 6.479953f-5 1.9057024f-5 1.9222587f-5 -8.121512f-5 5.171659f-5 3.638576f-5 -2.0906805f-6 -2.4223042f-5 -6.831937f-5 -0.00010120198 0.00016518368 1.2419263f-5; 4.782592f-5 0.00025389812 2.9272249f-5 5.241387f-5 0.00011946306 7.294716f-6 -6.277009f-5 5.007782f-5 6.9063026f-5 -5.2962896f-5 8.06442f-5 -1.8251754f-5 0.0002212222 2.9873003f-5 0.000101013895 -0.00015619757 -1.8870787f-5 2.0554508f-5 -7.2919356f-5 4.421159f-5 -9.025333f-5 2.1787566f-5 6.477845f-5 7.849543f-6 0.000118529024 -0.000100653415 0.00010128701 -0.00013492907 0.00021828087 7.255126f-5 -0.0001778174 3.1104297f-5; -2.9297704f-5 0.00015416983 0.00021400537 4.300042f-5 3.1535412f-5 0.000111028225 0.00010752696 0.00014255711 -6.24829f-5 -2.4159968f-5 -6.577889f-5 7.223851f-5 2.5814665f-5 -0.00011020142 0.00013548916 7.333311f-5 -0.00011637696 6.633607f-5 3.548159f-5 4.5124056f-5 4.899363f-7 -0.00022627885 8.7407f-5 -4.4922563f-6 0.0001603375 -5.692261f-5 -0.00015505812 -4.311619f-5 -9.957678f-5 4.5114153f-5 0.0001277764 4.0827654f-5; -0.00018884602 -8.252291f-5 -0.00020275004 -1.5996142f-5 
1.0948557f-5 -1.2531215f-6 2.1226431f-5 -0.00013950102 -6.244997f-5 6.189097f-5 8.317054f-6 0.00011069288 -6.0612667f-5 0.00023008672 -8.424004f-5 0.00011367997 7.365242f-5 8.5696796f-5 -2.6832919f-5 -4.5422978f-5 0.00011473443 3.4665725f-5 1.8310264f-5 6.8827154f-5 -2.7325694f-5 -6.988815f-5 4.289541f-5 -8.405043f-5 5.8925f-5 -2.1137472f-5 5.0164148f-5 -0.00011784828; 0.0001533554 0.00015822034 -8.252598f-5 -0.00018769057 3.7338465f-5 -6.259479f-5 2.4409404f-5 6.757417f-5 4.258912f-5 2.2142702f-5 -5.396056f-5 5.1712414f-5 -0.00011817409 6.020662f-5 0.00011015203 0.00010924333 9.301801f-5 3.6018584f-5 -0.00013172756 8.561092f-5 7.678158f-5 0.00011371709 -2.76895f-5 -7.738934f-5 5.9298916f-5 4.723197f-5 -0.00017537427 -0.00022233166 2.6152233f-5 -1.5492504f-5 -0.000110231376 -3.0445044f-5], bias = Float32[-0.05889702, -0.041791085, -0.1648654, -0.15569155, 0.06983162, -0.12103028, 0.15167998, 0.055664733, 0.0952091, 0.07429686, -0.1276857, -0.07680357, -0.039516486, -0.15966807, 0.110387124, -0.1681824, -0.014540031, 0.12944, -0.044445623, -0.08000397, -0.06320228, 0.1597552, -0.07586683, -0.08197452, -0.010823796, 0.06522438, -0.15413529, -0.056864467, 0.13072613, -0.1700964, 0.017366251, -0.17152534]), layer_4 = (weight = Float32[-5.104401f-5 1.430771f-5 9.083567f-5 9.083542f-5 0.00012574377 -2.4576737f-5 4.5244182f-5 -5.4398597f-6 0.00010451808 0.00010085061 -3.2676508f-5 -2.9209876f-5 0.00012008122 -6.3496635f-5 -2.9200675f-5 9.803897f-6 0.00011898517 -6.5927146f-5 -3.8370963f-5 1.9270436f-5 0.00013988838 2.2800106f-5 -3.268338f-5 -0.00024198752 0.00015849457 3.9682065f-5 1.4435376f-5 6.920682f-5 3.796451f-5 6.1765364f-5 -5.4829612f-5 -0.00014768887; -0.00017179378 -0.00016873606 6.418734f-5 9.887566f-6 4.863643f-5 -3.2233434f-5 -0.00011780128 6.4042615f-5 -0.00013209275 0.00012947312 -0.00016594848 0.0001316221 7.2413015f-5 0.00012479995 -0.00011841402 5.2480213f-5 2.3162862f-5 -0.0001412326 -3.3989803f-5 3.4169756f-5 2.4569532f-5 7.078854f-5 7.5379605f-5 
6.353248f-5 -6.7985246f-5 3.274739f-5 0.00018627595 -0.00015407652 5.046163f-5 0.00013702798 -8.360035f-5 -6.6784574f-5], bias = Float32[0.15760058, -0.15231954])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64.

julia
# Lux initialises the parameters in Float32; convert them to a Float64 ComponentArray
# so they match the Float64 ODE state and can be handed to the optimiser as one vector.
const params = ComponentArray{Float64}(ps)

# Wrap the network and its state so it can be called as `nn_model(x, ps)`.
# `nothing` is passed for the stored parameters because they are supplied on each call.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(broadcast), typeof(cos)}(broadcast, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs describing the motion of a point-like particle under Newtonian physics, with each rate multiplied by a neural-network correction, using

$u[1] = \chi$ and $u[2] = \phi$,

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = [χ, ϕ]`.

Each component is the Newtonian rate `(1 + e cos χ)² / (M p^{3/2})`, with `p, M, e`
taken from `ode_model_params`, multiplied by the matching entry of
`1 .+ nn_model([χ], nn_params)`. `t` is unused. Returns the vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # Every layer of this network has an empty state, so the stateful wrapper
    # `nn_model` can be called directly. In general, the state `st` must be tracked.
    correction = 1 .+ nn_model([χ], nn_params)

    # Newtonian rate shared by both components.
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Untrained neural ODE: same initial state, time span, solver and saving grid as the
# reference run, with the network parameters `params` as the ODE parameters.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# First component of the `compute_waveform` tuple, comparable to `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained prediction, each as line + markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean-squared-error loss object from Lux.
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural ODE with parameters `θ` (same solver settings as the reference run),
convert the solution to a waveform, and return `(mse, predicted_waveform)`. `mse` is
the mean-squared error between the reference `waveform` and the prediction.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h_pred = first(compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params))
    return mseloss(waveform, h_pred), h_pred
end
loss (generic function with 1 method)

Warm up the loss function:

julia
# Warm-up call: the first invocation compiles `loss` (and the ODE solve inside it).
# Returns `(mse, predicted_waveform)` for the untrained parameters.
loss(params)
(0.001823082651684008, [-0.02243847138556246, -0.02176213391238282, -0.02108579643920279, -0.019935334875737478, -0.018274306233959488, -0.016048137749408192, -0.013181859165921751, -0.009573118040414046, -0.005090213019040884, 0.0004348685088091308, 0.007201264692079408, 0.015408762149857905, 0.025159235621670474, 0.03612052861420048, 0.046669866145676496, 0.052139594912548944, 0.04422174684392486, 0.018771162384050365, -0.014516301717082404, -0.03941308389786896, -0.0477644923446369, -0.04428036143122951, -0.03603034292244731, -0.02688541521201509, -0.018336677733871848, -0.010811786592801282, -0.00434640023459461, 0.0011508712893751995, 0.005799667584407614, 0.009710258666963024, 0.012976645984354853, 0.0156726129964267, 0.017854428346082765, 0.019562242152695936, 0.02082213431876141, 0.021645928304623314, 0.02203334334237635, 0.02196995194897431, 0.02142647796083125, 0.020356037545118576, 0.01869061480664145, 0.016334348554324137, 0.013159356666558635, 0.008988355941855491, 0.0035916506965478563, -0.003322398212668652, -0.012092745625556957, -0.022987847557885972, -0.03581897437295034, -0.04878892818671767, -0.05623350930212631, -0.0486414344129684, -0.022826084036312956, 0.010173625846362385, 0.03432698351544518, 0.043199315629050244, 0.04163379602807581, 0.03553844716687772, 0.02815045863297335, 0.020839335680732898, 0.014101091795419752, 0.008072168207013192, 0.0027470708842234843, -0.0019249492977777363, -0.006006772175868256, -0.009556359324865881, -0.012622316481317043, -0.015242274420114298, -0.01744413694831037, -0.019240316233852195, -0.020638675348980205, -0.021628580565203632, -0.02219055844484452, -0.02228769084489577, -0.02186273164192886, -0.020837406079960836, -0.01909376666170443, -0.0164792157416331, -0.01276753071840482, -0.007667590730442213, -0.000768446142069001, 0.008420193839846596, 0.020394115349847797, 0.03513788612392903, 0.050616664825398984, 0.05993652952552253, 0.05254614765594502, 0.026566365430708627, -0.0058236321423043285, 
-0.029077828621218067, -0.0385102694485441, -0.038740032495551746, -0.034642060538337155, -0.0289374679022021, -0.022884284420241433, -0.017014536421049328, -0.01153977779129728, -0.006517909915785931, -0.001958419718507638, 0.002167460832523282, 0.00587899593965484, 0.00921021693360408, 0.012174561266787228, 0.014792993210590738, 0.017065637774157802, 0.01899445755467269, 0.020561858268345376, 0.02174407689495545, 0.022497197785519216, 0.022757805832431463, 0.022435716822495837, 0.021397797651348412, 0.01946261795684119, 0.016361418986387658, 0.011726337916502386, 0.005024545876165095, -0.004427529681491002, -0.01740101529294733, -0.034090822484633236, -0.052123093832495894, -0.06311427179017694, -0.05582530037968793, -0.02988124624364923, 0.001516165795070679, 0.023769956505064794, 0.03374870934209453, 0.0356088629922778, 0.03334085794449987, 0.029235811744213542, 0.024439160598101507, 0.019501606024774434, 0.01467478632631724, 0.010077969675182244, 0.005749542159663335, 0.0017097818049346377, -0.0020497734121972433, -0.005533420724733165, -0.008741543163333611, -0.011688399146157864, -0.014364877866680762, -0.01676485353660624, -0.018866803045577363, -0.0206464991403388, -0.022053723985184882, -0.02302084802003151, -0.023450823615624724, -0.02319725159378905, -0.022051989065500526, -0.01970575006379582, -0.015712368974856342, -0.009400501490741236, 0.00015139121376631247, 0.014032771046076448, 0.032698757377777034, 0.053252735580295966, 0.06562811994021295, 0.05837368389376742, 0.032717872026685545, 0.002695544099220877, -0.018490304077370463, -0.028951339788730972, -0.03225308607982348, -0.03164074709324329, -0.02903714906751419, -0.02547875426341768, -0.02151082573680176, -0.017412996146115228, -0.013334165199427213, -0.009357853418759766, -0.005513206871854576, -0.001825678085930179, 0.0016986389400696477, 0.005047199281464833, 0.008225769957995956, 0.011219310585280409, 0.014018088101677265, 0.01659733761697711, 0.018933156427858113, 0.020973785464184914, 
0.022650843317046778, 0.023862478334226243, 0.024455600164091235, 0.024201357147538453, 0.022749408232861718, 0.019574255159918237, 0.013848616172830407, 0.0043816722940055374, -0.010319987290358152, -0.030980398473973606, -0.05393409841338552, -0.06734995874385699, -0.060113645354089584, -0.03505934847536685, -0.006770107274433367, 0.013309286782019694, 0.02414782822401491, 0.028691850989135348, 0.029551530065104227, 0.028340017921065778, 0.025982197504783734, 0.023000136616728596, 0.019686676877581008, 0.016213108910504637, 0.012676580495157273, 0.009141723650615002, 0.005634210450743631, 0.0021813277267197363, -0.0012056744878912596, -0.0045083659540542725, -0.007726150084805432, -0.010837547154770985, -0.013825376516154033, -0.01665575824119766, -0.01929185649382079, -0.021662282644756545, -0.023667096509988307, -0.025151955830090264, -0.025873953113133084, -0.0254500751924619, -0.02326350699826509, -0.018335139836177558, -0.00913821732400304, 0.006297747578256574, 0.028955601867777486, 0.05406801175706173, 0.06818679042113111, 0.061005426450113996, 0.03693216759885086, 0.010663479530652909, -0.008283391550027858, -0.01936853853330215, -0.024947454584347904, -0.027088897861336034, -0.027149507086643498, -0.025935463029689944, -0.023934146149450573, -0.021440815518917456, -0.018633177507857774, -0.015624672751902649, -0.012484763742314563, -0.009265138881991397, -0.005986598967060377, -0.0026752756192732923, 0.0006571007800926081, 0.003990606124488848, 0.0073226008345873336, 0.010628811274770365, 0.013887567618086521, 0.017056841192097695, 0.02022611476610902])

Now let us define a callback function that records the loss at every iteration.

julia
# Loss history, appended to once per optimiser iteration by `callback`.
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Optimisation callback: append the current loss `l` to `losses`, print a progress line,
and return `false` so the optimiser never stops early. `θ` and `pred_waveform` are
ignored.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf("Training %10s Iteration: %5d %10s Loss: %.10f\n", "", iteration, "", l)
    stop = false
    return stop
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.

julia
# Gradients of the objective are obtained with Zygote (reverse-mode AD).
adtype = Optimization.AutoZygote()
# The objective ignores the problem-parameter slot and depends only on the trainable
# parameters. `loss` and `params` are assumed to be defined earlier in this file.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step scale and a backtracking line search; `callback`
# records and prints the loss at every iteration.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [6.623103399761021e-5; 8.04074079496786e-5; 2.414684604445938e-5; 9.728672739583999e-5; -7.738071872154251e-5; -1.341272218269296e-5; 5.515409793588333e-5; 0.0001145433125202544; -0.000117878669698257; -2.319591112609487e-5; 0.0001925346878124401; 1.1898329830728471e-5; -0.00010297564585926011; 2.5933895813068375e-5; 9.880326979327947e-5; 5.309229891281575e-5; 3.188947812304832e-5; 0.0001209419933729805; 1.6285752280964516e-5; -0.00017049464804586023; -6.519984890474007e-5; 0.00010549427679507062; 0.00012267952843103558; -2.341446997888852e-5; -2.5151797672151588e-5; -5.718582178815268e-5; -8.736814197618514e-5; 1.622802301426418e-5; -4.905609239358455e-5; -0.00011611969239311293; -0.00018902734154835343; -0.00017749246035236865;;], bias = [0.31793880462646484, -0.5745903253555298, -0.0993880033493042, 0.1678934097290039, -0.577652096748352, 0.3880941867828369, -0.11283981800079346, -0.1341947317123413, -0.04003322124481201, 0.2762101888656616, -0.5337069034576416, 0.2575347423553467, -0.3813828229904175, -0.2581537961959839, 0.11033511161804199, 0.9027512073516846, -0.9034305810928345, 0.2593449354171753, 0.5407521724700928, -0.7905962467193604, 0.34238505363464355, -0.8998452425003052, -0.6717230081558228, 0.1384742259979248, -0.5991702079772949, -0.45280134677886963, -0.5255485773086548, -0.7368806600570679, -0.53758704662323, -0.7259551286697388, 0.9868354797363281, -0.8551162481307983]), layer_3 = (weight = [-3.066320641664788e-5 -0.00016208204033318907 0.00012403492291923612 5.1166542107239366e-5 4.51845298812259e-5 8.569371129851788e-5 -7.148225995479152e-5 -0.00011333310249028727 4.00990029447712e-5 -2.8557087716762908e-5 -0.00017484757700003684 -5.276063166093081e-5 -8.461391553282738e-5 -5.6824203056748956e-5 -1.5841305867070332e-5 3.516722426866181e-5 -8.97798890946433e-5 -8.805507241049781e-5 5.8567733503878117e-5 -0.00011041424295399338 -9.166553354589269e-5 -0.0001346104545518756 
-0.00013812865654472262 -0.00012434364180080593 2.252775084343739e-5 9.493881952948868e-5 5.235604112385772e-5 -3.95043462049216e-5 4.8868514568312094e-5 6.405822205124423e-5 -6.798324466217309e-5 0.0001304072793573141; 0.00016163234249688685 6.3927423070708755e-6 -3.755516809178516e-5 5.097933535580523e-5 -5.695356594515033e-5 -8.026776777114719e-5 1.5825964510440826e-5 -0.00033238768810406327 -0.00012743417755700648 2.8482645575422794e-5 9.502375178271905e-5 -8.803088712738827e-5 -1.6042032484619995e-6 3.23027074955462e-6 -0.00020780506019946188 -1.5358116343122674e-6 -2.771161462078453e-6 0.0002507960598450154 -4.3316475057508796e-5 -4.44072975369636e-5 1.2994172720937058e-5 5.862347825313918e-5 -2.9825001547578722e-5 -0.00010641366679919884 6.354024662869051e-5 6.736676732543856e-5 -2.6056592105305754e-5 -5.273430360830389e-5 4.7542482207063586e-5 -3.363144429613385e-8 3.4210570447612554e-5 -3.859898788505234e-5; 5.443279951578006e-5 6.471797678386793e-5 4.276385152479634e-5 -2.8502559871412814e-5 -3.7751326544821495e-6 -0.00014484886196441948 0.00023496657377108932 -1.5421619536937214e-5 2.017825681832619e-5 4.573751994030317e-6 0.00010298057168256491 -8.516784873791039e-5 4.784231714438647e-5 -0.000167734338901937 4.808999437955208e-5 0.00014157309487927705 0.00016199728997889906 9.928568033501506e-5 8.21498833829537e-5 0.00012435422104317695 -0.00010726531036198139 0.0002510518243070692 -0.00014175062824506313 -8.954718214226887e-5 -0.00012058134598191828 0.00010539258801145479 -1.8787523003993556e-5 -0.00010123863467015326 -4.080401049577631e-5 5.8521676692180336e-5 8.495256042806432e-5 0.00013503202353604138; -0.00013757272972725332 -0.00018927536439150572 -3.250950976507738e-5 -9.276693890569732e-5 0.00013743943418376148 -0.00011121622083010152 -0.0001250889035873115 -0.00011172029189765453 -6.32280352874659e-5 0.00014294794527813792 6.289049633778632e-5 0.00010335798287997022 -0.0001610355102457106 6.318500527413562e-5 -4.8823028919287026e-5 
-2.2549253117176704e-5 -6.477511487901211e-5 -5.0420454499544576e-5 -9.517943544778973e-5 -1.6837404473335482e-5 0.00015117975999601185 3.34803517034743e-5 3.524025305523537e-5 -0.00016124980174936354 -5.1954521040897816e-5 0.00020028659491799772 -1.7100081095122732e-5 2.3424989194609225e-5 5.972875442239456e-5 -7.312718662433326e-5 4.086078115506098e-5 -9.338634117739275e-5; -7.035169255686924e-5 -0.00016396335558965802 5.324750964064151e-5 3.237143755541183e-5 9.259108628612012e-5 -0.00011410086153773591 8.62952001625672e-5 -7.95544619904831e-5 0.00012342524132691324 1.6911349121073727e-7 -0.00012100812455173582 -6.289059820119292e-5 0.00012148164387326688 7.060254574753344e-5 -0.0001208738685818389 -4.385069405543618e-5 -2.1862604626221582e-5 -6.524175842059776e-5 -2.3515788143413374e-6 0.0001875921880127862 -0.0001665600429987535 -0.00015281200467143208 2.1746474885731004e-5 1.4451544302573893e-5 -8.709109351912048e-6 -9.999443864217028e-5 2.9582351999124512e-5 -6.567369564436376e-5 -7.889764674473554e-5 0.00012336199870333076 -1.6373083781218156e-5 5.240089012659155e-5; 3.215737160644494e-5 0.00017174211097881198 -6.436769035644829e-5 9.610915731173009e-5 7.724755414528772e-5 6.55643452773802e-5 -0.00022910685220267624 -3.567989551811479e-5 6.555541040142998e-5 9.259075886802748e-5 -7.427591481246054e-5 -0.00013147061690688133 0.00014524688594974577 -2.025256799242925e-5 -4.731147419079207e-5 -0.0001770484959706664 -5.9182304539717734e-5 -0.0001169736497104168 -5.5630582210142165e-5 -7.347584323724732e-5 -4.8541976866545156e-5 -2.991698420373723e-5 2.5520421331748366e-5 0.00014857026690151542 5.135107585374499e-6 2.3291120669455267e-5 4.2443847632966936e-5 -3.185659443261102e-5 7.654835644643754e-5 3.226113403798081e-5 2.774548738671001e-5 -8.362841617781669e-5; -0.00010835041757673025 -3.686245690914802e-5 0.00019454717403277755 3.230377478757873e-5 -4.321483720559627e-5 -0.00010914765880443156 -4.1139857785310596e-5 -2.287914412590908e-6 
-8.886688738130033e-5 -1.7083088096114807e-5 0.00016147768474183977 6.0316157032502815e-5 -7.065937825245783e-5 0.00012093110854038969 -0.0002104804152622819 -0.00022622512187808752 3.6913614167133346e-5 4.804161289939657e-5 -0.00011625155457295477 0.0001723332388792187 8.429029549006373e-5 -9.46438103710534e-6 5.563709419220686e-5 -8.99267106433399e-5 2.3078921003616415e-5 -0.00010308122000424191 -0.00021912339434493333 -0.0001791121467249468 4.933511718263617e-6 0.00010599909728625789 -0.00011643459583865479 0.00011786930554080755; -0.0001751450472511351 9.77992604020983e-5 0.00011858776269946247 -4.727124542114325e-5 4.98971639899537e-5 -0.0001596062647877261 -0.0001544783590361476 6.139549805084243e-5 -0.00016223295824602246 5.64647734790924e-6 -9.982697520172223e-6 0.00013951071014162153 -0.00011731923586921766 -9.023216989589855e-5 -1.012849224935053e-5 -0.00018665727111510932 5.735603554057889e-5 8.521597919752821e-5 -6.846401811344549e-5 5.323246296029538e-5 -7.143970287870616e-5 -6.471490632975474e-5 -9.713749750517309e-5 -0.00010063836089102551 -8.514572255080566e-5 4.247216202202253e-5 -0.00012559854076243937 -6.437741831177846e-5 -2.961927020805888e-5 -9.2150985437911e-5 7.897783507360145e-5 9.754821803653613e-6; 9.00228405953385e-5 -9.582464554114267e-5 9.974843123927712e-5 0.00018277709023095667 -6.139411561889574e-5 -1.952635830093641e-5 5.043117926106788e-5 -4.76417517347727e-5 -5.1915831136284396e-5 -5.078912363387644e-5 -0.0001176639852928929 -1.9604824046837166e-5 -4.3107262172270566e-5 7.499624189222232e-5 6.268055585678667e-5 0.00012511902605183423 -0.000110799512185622 -1.2293222425796557e-5 0.0002040079707512632 -7.68102690926753e-5 -6.365784793160856e-5 -0.0002171678061131388 0.0001137144718086347 8.638166036689654e-5 3.8846406823722646e-5 7.795066630933434e-5 -1.4519519936584402e-5 9.143730494542979e-6 4.498418638831936e-5 4.5712738938163966e-5 -7.943672244437039e-5 -0.0001529736036900431; -0.0002487634483259171 -6.947316433070228e-5 
0.00010514772293390706 -0.00010588709847070277 -2.465005854901392e-5 -0.000113583606434986 3.660443326225504e-5 -9.491879609413445e-5 3.926403951481916e-5 1.8800343241309747e-5 -1.2617713764484506e-5 -0.0002101285062963143 9.51678302953951e-5 -9.760648390511051e-5 -1.5120195712370332e-5 -0.0001752411772031337 0.00012285026605241 9.534349374007434e-5 -0.0001095376064768061 -9.781374683370814e-5 -0.00010382624168414623 8.059813262661919e-5 4.384923158795573e-5 -3.4456847060937434e-5 1.9892257114406675e-5 0.0001437248574802652 4.0145117964129895e-5 4.653440191759728e-5 5.451406104839407e-5 1.285945199924754e-5 5.1909792091464624e-5 0.00020366586977615952; -0.00011216841085115448 -9.352545021101832e-5 1.0820324860105757e-5 0.00011497941159177572 6.333414785331115e-5 -2.958538243547082e-5 -5.484149028234242e-7 5.838753713760525e-5 3.6762830859515816e-5 -3.178475026288652e-6 6.456116534536704e-5 0.00012109539238736033 -0.0002050480543402955 3.482307874946855e-5 -1.2859150047006551e-5 0.000103768463304732 9.380684059578925e-5 -0.00012004763266304508 -5.693401544704102e-5 1.0675703379092738e-5 0.00024600850883871317 0.00011793351586675271 4.978631477570161e-5 3.638062935351627e-6 -1.3785427654511295e-5 4.576845822157338e-5 -0.00012773845810443163 3.9342954551102594e-5 5.677768058376387e-6 -0.00011390564031898975 3.0366209102794528e-5 5.233010142546846e-6; 2.1621508494718e-5 -6.603497604373842e-5 -0.00010154794290428981 -3.7421457818709314e-5 -9.906072227749974e-5 -6.896163540659472e-5 0.0001614856591913849 5.5673237511655316e-5 -0.0001957954082172364 4.018818435724825e-5 4.682844155468047e-5 0.00014438667858485132 -1.1681819160003215e-5 3.104835195699707e-5 -1.5794890714460053e-5 -0.0001276279508601874 4.5579545258078724e-5 -8.268237434094772e-5 1.7668644431978464e-5 -3.099075547652319e-5 1.4805748833168764e-5 -8.604606409789994e-5 -6.0001366364303976e-5 6.429431959986687e-5 -8.514489877597953e-7 -0.00011787220137193799 1.908417107188143e-5 0.00011041435209335759 
4.7270761569961905e-5 2.7763360776589252e-5 -9.454790415475145e-5 -7.425008334394079e-6; -2.1574311176664196e-5 -2.6311769033782184e-5 -0.000145624260767363 -0.00010406075307400897 4.4424301449907944e-5 7.254353840835392e-5 3.059609662159346e-5 2.960028950838023e-6 9.371277701575309e-5 -6.059155202819966e-5 -0.00019529451674316078 6.148526153992862e-5 0.00013258564285933971 -0.00016935916210059077 6.575080624315888e-5 1.5875180906732567e-5 7.131070015020669e-5 -2.367636807321105e-5 0.0001257223339052871 7.738668500678614e-5 -6.217160989763215e-5 -5.730355951527599e-6 -0.0001791437534848228 -9.619543561711907e-5 6.381152343237773e-5 -2.3233909814734943e-5 -5.762921500718221e-5 0.00012693801545538008 -0.00019640980463009328 -6.176497117849067e-5 0.00024618851603008807 -0.00019073607109021395; -6.973165727686137e-5 0.00016393001715186983 -0.00010235659283353016 0.00014490926696453243 -6.611115531995893e-5 -0.00013216002844274044 3.0489303753711283e-5 -0.00016050801787059754 7.238626130856574e-5 7.409246791212354e-6 6.60271180095151e-5 -1.0795994967338629e-5 -0.00016162633255589753 -1.05804729173542e-5 -3.256920172134414e-5 -5.236438664724119e-5 -0.00017188723722938448 -3.219758582417853e-5 6.396635581040755e-5 0.00016437770682387054 -1.0848955753317568e-5 -0.00020082433184143156 3.631001163739711e-5 3.1758027034811676e-5 -2.8885220672236755e-5 2.9797476599924266e-5 0.0001742174499668181 0.00012714171316474676 -8.191414963221177e-5 6.97078721714206e-5 7.14637353667058e-5 0.000137034134240821; -7.570000161649659e-5 0.00010502369696041569 -1.3470016710925847e-6 9.852635412244126e-5 -7.248130714287981e-5 -7.234537042677402e-5 -2.122823389072437e-5 4.0759474359219894e-5 6.582406786037609e-5 -7.091985025908798e-5 3.043964534299448e-5 7.74763393565081e-5 3.9947011828189716e-5 7.922375516500324e-5 -5.422886533779092e-5 8.95944976946339e-5 0.00013370488886721432 0.0001143946428783238 -0.00018512856331653893 2.8345791633910267e-6 -8.401538252655882e-6 -7.745569746475667e-6 
0.00010549587750574574 6.855932497273898e-6 7.406888471450657e-5 -5.573009548243135e-5 1.942682320077438e-5 -0.00017407213454134762 4.874295700574294e-5 0.00011683884804369882 -4.122505742998328e-6 2.7556838176678866e-5; 6.318239320535213e-5 -5.327867143023468e-7 6.385541200870648e-5 -0.00016248617612291127 7.975425978656858e-5 -0.000191142171388492 -5.9078953199787065e-5 5.668984613294015e-6 -1.8747243302641436e-5 -6.302775727817789e-5 -0.00012834352673962712 6.798004648089773e-8 0.00012352173507679254 0.00013674837828148156 -3.8609076000284404e-5 0.00010451606067363173 -0.00010724886669777334 -0.00010064574598800391 -6.82955578668043e-5 5.548913759412244e-5 0.00019383073959033936 -0.0002896118094213307 8.806363621260971e-5 -1.8685253962757997e-5 -0.00017268460942432284 4.6597451728302985e-5 -0.0001044773671310395 1.4620113688579295e-5 0.00014239450683817267 0.0001293497916776687 -0.00021476521214935929 7.983770046848804e-5; -7.336109410971403e-5 -3.550534165697172e-5 -8.465105929644778e-5 -7.954019383760169e-5 5.475783473229967e-5 1.0680539617169416e-6 8.671901014167815e-5 -1.732686723698862e-5 4.427932708495064e-6 -2.0321138435974717e-5 -7.647649181308225e-5 -7.291173824341968e-5 1.526052801636979e-5 5.0832863053074107e-5 0.0001400700566591695 -4.206240555504337e-5 -0.00010452727292431518 -8.1815873272717e-5 -6.098094308981672e-5 -5.751418575528078e-5 -4.6058026782702655e-5 -7.294563692994416e-5 1.238010372617282e-5 0.00017351575661450624 -6.152009154902771e-5 0.00020497704099398106 -8.0568999692332e-5 9.55104042077437e-6 -0.00023729742679279298 -0.00010660379484761506 3.5690670756594045e-6 1.3889041838410776e-5; -6.074308112147264e-5 1.4326351447380148e-5 -2.668652086867951e-5 1.9601170606620144e-6 -0.00010868177196243778 1.617435918888077e-5 0.0001299596333410591 4.317798448028043e-5 -3.890359221259132e-5 -0.0001710675423964858 2.273905192851089e-5 -0.00012905937910545617 -0.00013590522576123476 0.0002607076021376997 -0.00012426219473127276 
-4.991480091121048e-5 5.834399053128436e-5 -6.386244422174059e-6 -8.270691614598036e-5 -8.425392479693983e-6 3.0067170882830396e-5 -8.343212539330125e-5 0.00026114319916814566 -5.679417881765403e-5 -5.157190753379837e-5 -7.826549699530005e-5 0.00011965966405114159 -6.63971368339844e-5 0.0001040586648741737 -5.7320299674756825e-5 0.00011982925207121298 3.87177751690615e-5; -3.130982076982036e-5 -7.180091051850468e-5 6.296856008702889e-5 3.1886509532341734e-5 -0.00020825497631449252 1.5046343833091669e-5 -0.00018462669686414301 -6.167479295982048e-5 0.00010625415598042309 8.564352901885286e-5 6.847112672403455e-5 -9.625674283597618e-5 -3.675335756270215e-5 -4.106189953745343e-5 1.5229858945531305e-5 6.760371616110206e-5 -2.723778561630752e-5 -1.6027952369768173e-5 7.372334221145138e-5 -3.0940689612179995e-5 -4.394532152218744e-5 -6.185149686643854e-6 -6.285694689722732e-5 -0.00011073664791183546 -7.768897921778262e-5 1.9851402612403035e-5 -7.893139809311833e-6 8.793458255240694e-5 0.00017658004071563482 -2.7939131541643292e-5 -0.00011810127762146294 -1.803226587071549e-5; 0.0001986967836273834 -1.1987847756245174e-5 0.00017685603233985603 8.110483759082854e-5 0.0001056439577951096 -5.548438639380038e-5 -5.367614721762948e-5 8.37731349747628e-5 8.704192441655323e-5 -2.1924473912804388e-5 -7.2476177592761815e-6 -0.00012712278112303466 4.483083830564283e-5 3.76258249161765e-5 -1.961865200428292e-5 -0.00023658806458115578 5.147798219695687e-5 0.0001764129992807284 5.0713704695226625e-5 -0.00021392651251517236 -0.00019777489069383591 -0.00014323773211799562 -0.00010341144661651924 0.00010857605957426131 -6.285469862632453e-5 0.0002517055254429579 -8.569101191824302e-5 0.00019514476298354566 -1.4831382031843532e-5 4.496162728173658e-5 -6.146720988908783e-5 -8.185344631783664e-5; 1.4746627130080014e-5 -1.858247742347885e-6 -4.814622297999449e-5 8.542099385522306e-5 -7.314206595765427e-5 -9.958537702914327e-5 5.842042446602136e-5 9.904891157930251e-6 -2.6083329430548474e-5 
0.00017538138490635902 -3.23211497743614e-5 -8.859020454110578e-5 5.352229345589876e-5 7.389570964733139e-5 0.00015304316184483469 3.624664532253519e-5 -0.00012947418144904077 -8.60269574332051e-5 2.360279358981643e-5 8.914579666452482e-5 1.3862626474292483e-5 3.196934630977921e-5 -5.70637930650264e-5 -3.441702574491501e-5 -1.0618860869726632e-5 -0.00010106251284014434 2.6949624952976592e-5 2.8241904146852903e-5 0.00014759771875105798 -0.0001410322292940691 0.00011848979920614511 -0.00013958460476715118; -0.00010115728218806908 2.0939998648827896e-5 -6.505542114609852e-5 6.63632235955447e-5 2.8154969186289236e-5 -0.00013346225023269653 -5.6889250117819756e-5 0.00015346093277912587 5.177968887437601e-6 -5.2735183999175206e-5 6.955644494155422e-5 -1.757926474965643e-5 -6.448856584029272e-5 4.5354870962910354e-5 1.3689276784134563e-5 -7.890482811490074e-5 -3.786823435802944e-5 6.548718374688178e-5 -2.6990112019120716e-5 4.064413587911986e-5 4.740276199299842e-5 -0.00010507968545425683 -1.0512115295568947e-5 3.078015652135946e-5 6.861355359433219e-5 0.00011222385364817455 -3.651538281701505e-5 -0.00013350062363315374 1.8207536413683556e-5 -8.354349120054394e-5 1.8156896430809866e-6 -1.858351970440708e-5; 0.0001360656606266275 5.940016490058042e-5 3.22400446748361e-5 0.00013480307825375348 -3.2215123724199657e-7 -4.771231033373624e-5 1.4230568012862932e-6 6.177137402119115e-5 3.122276393696666e-5 -2.1518014818866504e-6 0.00020311416301410645 2.2963396986597218e-5 5.317318027664442e-6 3.9302387449424714e-5 8.032975893002003e-5 -0.00022411542886402458 -1.854369656939525e-5 -2.1121340978424996e-5 -1.3416326510196086e-5 1.1960672054556198e-5 -7.757317689538468e-6 -1.4184696510710637e-6 2.4919212592067197e-5 -8.083391730906442e-5 1.9215294742025435e-5 -2.0364544980111532e-5 3.597277782318997e-6 -4.513767271419056e-5 -7.64671276556328e-5 0.00010511809523450211 2.897430385928601e-5 -0.00011276114673819393; -8.25506285764277e-5 4.146117134951055e-5 -0.00019277560932096094 
4.5782064262311906e-5 -6.83256148477085e-5 6.131696136435494e-5 -5.347165642888285e-5 -6.660153303528205e-5 0.00014792740694247186 -0.00015370876644738019 3.360252230777405e-5 2.9659629490197403e-6 8.282202179543674e-5 -0.00010371023381594568 -0.0002092696086037904 2.0165312889730558e-5 1.5897242064966122e-6 -4.989587978343479e-5 9.976803266908973e-5 4.867512689088471e-5 7.434738654410467e-5 -1.7284630303038284e-5 0.00011013571929652244 -3.293528425274417e-5 0.0001310309744440019 6.293200840445934e-6 7.629340689163655e-5 7.893281144788489e-5 -0.00010510552965570241 -0.00023015000624582171 -6.380729610100389e-5 0.0001340621238341555; 0.00012888865603599697 5.896265429328196e-5 -0.00032783509232103825 -1.1387808626750484e-5 0.00015898977289907634 -9.423246956430376e-5 -1.5929817891446874e-5 0.00010405903594801202 2.4859326003934257e-5 3.669784200610593e-5 -0.00010503850353416055 -7.783746696077287e-5 0.00011095857189502567 0.00015315349446609616 8.095514203887433e-6 0.00010063535592053086 -0.00017611027578823268 -0.00022151658777147532 -4.2956355173373595e-5 4.474334491533227e-5 -3.6020497645949945e-5 -9.094438428292051e-5 -2.4893583031371236e-5 1.3325802683539223e-5 7.477769395336509e-5 0.00021454651141539216 0.00015494044055230916 -0.0001825714425649494 5.271952613838948e-5 1.5875575627433136e-6 3.85277644454618e-6 0.00011455146159278229; -1.0510611900826916e-5 -7.950789586175233e-5 -1.7161590221803635e-5 5.092297215014696e-5 -7.652562635485083e-5 -0.0001944080722751096 0.00010505835234653205 -2.224738273071125e-5 -3.333374479552731e-5 -7.032260327832773e-5 -0.00013541193038690835 -9.046142076840624e-5 4.9246096750721335e-5 -6.852960359537974e-5 -5.940470509813167e-5 0.00013898716133553535 9.614507871447131e-5 -3.584153819247149e-5 -7.686789467697963e-5 1.326485835306812e-5 -8.698801684658974e-5 -1.1008335604856256e-5 0.00014101243868935853 -0.00011990972416242585 1.6569201122251798e-8 -8.309714758070186e-5 6.639973435085267e-5 -0.00022338086273521185 
-0.00010617995576467365 -9.557948942529038e-5 0.00010713081428548321 -6.687887798761949e-5; 5.321419666870497e-5 0.0002093711809720844 4.1943512769648805e-5 7.755596016068012e-5 5.160763248568401e-5 0.00010474593000253662 6.916094571352005e-5 1.661634996708017e-5 -7.300738161575282e-6 -0.00015646120300516486 -0.00012868325575254858 -0.00016819646407384425 6.291149475146085e-5 7.368083606706932e-5 9.485799091635272e-5 -9.57761294557713e-5 0.00010897742322413251 -2.8192138415761292e-5 -3.5020813811570406e-5 1.472859366913326e-6 -7.347123755607754e-5 -0.00010769833897938952 9.925959602696821e-5 0.00016536931798327714 -9.312228939961642e-6 -5.87226495554205e-5 6.153684353193967e-6 5.504551882040687e-5 0.0001650980848353356 0.00016365373448934406 -7.19800591468811e-5 0.00012085911293979734; 8.185448677977547e-5 -0.0002509044425096363 0.00016438736929558218 3.684989860630594e-5 -0.0001319770817644894 7.415497384499758e-5 -7.256875960592879e-6 -9.543281339574605e-5 -9.114681347455189e-7 -7.585663115605712e-5 -8.449495680906693e-7 -1.5517665815423243e-5 8.333241567015648e-5 -1.0012804523285013e-5 -7.135060877772048e-5 5.1770566642517224e-5 -5.2280305681051686e-5 -3.696867133839987e-5 3.826117790595163e-6 -3.337403541081585e-5 6.479953299276531e-5 1.905702447402291e-5 1.922258707054425e-5 -8.121511928038672e-5 5.1716589950956404e-5 3.6385761632118374e-5 -2.090680482069729e-6 -2.4223041691584513e-5 -6.831937207607552e-5 -0.000101201978395693 0.00016518367920070887 1.2419262930052355e-5; 4.782592077390291e-5 0.0002538981207180768 2.927224886661861e-5 5.241387043497525e-5 0.00011946306040044874 7.2947159424074925e-6 -6.27700865152292e-5 5.007781874155626e-5 6.906302587594837e-5 -5.29628960066475e-5 8.064419671427459e-5 -1.8251754227094352e-5 0.00022122220252640545 2.9873002858948894e-5 0.00010101389489136636 -0.00015619756595697254 -1.8870787243940867e-5 2.0554507500492036e-5 -7.291935617104173e-5 4.421158882905729e-5 -9.02533283806406e-5 2.1787565856357105e-5 
6.47784472675994e-5 7.849543180782348e-6 0.00011852902389364317 -0.00010065341484732926 0.00010128701251232997 -0.00013492906873580068 0.00021828086755704135 7.255125819938257e-5 -0.00017781740461941808 3.110429679509252e-5; -2.929770380433183e-5 0.00015416982932947576 0.00021400536934379488 4.300041837268509e-5 3.153541183564812e-5 0.00011102822463726625 0.00010752696107374504 0.00014255710993893445 -6.248289719223976e-5 -2.415996823401656e-5 -6.577889143954962e-5 7.223850843729451e-5 2.581466469564475e-5 -0.00011020142119377851 0.0001354891574010253 7.333311077672988e-5 -0.00011637696297839284 6.633606972172856e-5 3.548158929334022e-5 4.512405575951561e-5 4.899363261756662e-7 -0.00022627884754911065 8.74070028658025e-5 -4.492256266530603e-6 0.00016033749852795154 -5.692261038348079e-5 -0.00015505812189076096 -4.311618977226317e-5 -9.957677684724331e-5 4.511415318120271e-5 0.00012777639494743198 4.08276537200436e-5; -0.00018884602468460798 -8.252290717791766e-5 -0.00020275004499126226 -1.5996141883078963e-5 1.0948557246592827e-5 -1.2531214679256664e-6 2.1226431272225454e-5 -0.00013950101856607944 -6.244997348403558e-5 6.189096893649548e-5 8.317054380313493e-6 0.00011069288302678615 -6.061266685719602e-5 0.00023008671996649355 -8.424004045082256e-5 0.00011367996921762824 7.365242345258594e-5 8.56967963045463e-5 -2.683291859284509e-5 -4.542297756415792e-5 0.00011473443009890616 3.466572525212541e-5 1.8310263840248808e-5 6.882715388201177e-5 -2.7325693736202084e-5 -6.988814857322723e-5 4.2895411752397195e-5 -8.405042899539694e-5 5.892500121262856e-5 -2.113747177645564e-5 5.016414797864854e-5 -0.00011784827802330256; 0.00015335540228988975 0.00015822034038137645 -8.252597763203084e-5 -0.0001876905735116452 3.733846460818313e-5 -6.259478686843067e-5 2.44094044319354e-5 6.757416849723086e-5 4.258911940269172e-5 2.214270170952659e-5 -5.39605607627891e-5 5.171241355128586e-5 -0.00011817408812930807 6.020662112860009e-5 0.00011015203199349344 0.00010924333037110046 
9.301801037508994e-5 3.6018584069097415e-5 -0.0001317275600740686 8.561091817682609e-5 7.678157999180257e-5 0.0001137170911533758 -2.7689500711858273e-5 -7.738934073131531e-5 5.929891631240025e-5 4.7231969801941887e-5 -0.00017537426901981235 -0.00022233165509533137 2.615223274915479e-5 -1.5492503735003993e-5 -0.00011023137631127611 -3.0445044103544205e-5], bias = [-0.05889701843261719, -0.04179108515381813, -0.1648654043674469, -0.15569154918193817, 0.06983161717653275, -0.1210302785038948, 0.15167997777462006, 0.05566473305225372, 0.09520909935235977, 0.07429686188697815, -0.12768569588661194, -0.07680357247591019, -0.03951648622751236, -0.15966807305812836, 0.11038712412118912, -0.1681824028491974, -0.014540030620992184, 0.12943999469280243, -0.044445622712373734, -0.08000396937131882, -0.06320227682590485, 0.15975520014762878, -0.07586683332920074, -0.08197452127933502, -0.010823795571923256, 0.06522437930107117, -0.15413528680801392, -0.05686446651816368, 0.13072612881660461, -0.17009639739990234, 0.017366250976920128, -0.17152534425258636]), layer_4 = (weight = [-5.104401134303771e-5 1.430771044397261e-5 9.083566692424938e-5 9.08354195416905e-5 0.00012574376887641847 -2.457673690514639e-5 4.524418181972578e-5 -5.439859705802519e-6 0.00010451808338984847 0.00010085060785058886 -3.267650754423812e-5 -2.9209875719971023e-5 0.00012008121848339215 -6.34966345387511e-5 -2.920067527156789e-5 9.803897228266578e-6 0.00011898516822839156 -6.592714635189623e-5 -3.837096301140264e-5 1.9270435586804524e-5 0.00013988837599754333 2.2800106307840906e-5 -3.2683379686204717e-5 -0.00024198752362281084 0.00015849457122385502 3.968206510762684e-5 1.4435376215260476e-5 6.920682062627748e-5 3.7964509829180315e-5 6.176536408020183e-5 -5.482961205416359e-5 -0.00014768887194804847; -0.0001717937848297879 -0.00016873606364242733 6.418734119506553e-5 9.887566193356179e-6 4.863642971031368e-5 -3.2233434467343614e-5 -0.00011780128261307254 6.40426151221618e-5 -0.0001320927549386397 
0.0001294731191592291 -0.00016594848420936614 0.00013162210234440863 7.241301500471309e-5 0.00012479994620662183 -0.00011841402010759339 5.2480212616501376e-5 2.3162861907621846e-5 -0.0001412325946148485 -3.3989803341683e-5 3.416975596337579e-5 2.4569531888118945e-5 7.078854105202481e-5 7.537960482295603e-5 6.353248318191618e-5 -6.798524555051699e-5 3.2747389923315495e-5 0.0001862759527284652 -0.00015407652244903147 5.046162914368324e-5 0.0001370279787806794 -8.360035280929878e-5 -6.678457430098206e-5], bias = [0.15760058164596558, -0.15231953561306]))

Visualizing the Results

Let us now plot the loss over the training iterations.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # Loss history recorded by the training callback: a line through all values plus
    # a marker at every iteration.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results.

julia
# Re-solve the ODE with the optimized parameters `res.u`, using fixed-step RK4
# (`adaptive=false`) and saving at `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(
    prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false))
# Keep only the first value returned by `compute_waveform` (presumably the waveform itself).
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Each series is drawn as a translucent line with a marker at every time step.

    # Reference waveform data.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    # Neural-network prediction before training.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    # Neural-network prediction after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    # Pair each line with its markers so the legend shows both glyphs per entry.
    entries = [[l1, s1], [l2, s2], [l3, s3]]
    labels = ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"]
    axislegend(ax, entries, labels; position=:lb)

    fig
end

Appendix

julia
# Print the Julia version and system information used to generate this page.
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices is loaded. Each backend is
# reported only if its package is loaded and `MLDataDevices.functional` returns true
# for its device type. The `@isdefined` checks come first, so `&&` short-circuits
# before any undefined backend name would be evaluated.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.