Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
      Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; however, from a scientific machine learning perspective, they aren't especially relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3)

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) `path` into the two individual body paths.

The first body receives `path` scaled by `m₂ / (m₁ + m₂)` and the second body
receives `path` scaled by `-m₁ / (m₁ + m₂)`, so the difference of the two returned
paths reproduces `path`. Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix into Cartesian orbit coordinates.

`soln` holds one time sample per column. Its first dimension must be either

  - `2`: rows are `χ` and `ϕ`; `p` and `e` are then taken from `model_params`, which
    must be a length-3 collection ordered as `(p, M, e)` (`M` is not used here), or
  - `4`: rows are `χ`, `ϕ`, `p` and `e`; `model_params` is ignored.

The radius is `r = p / (1 + e * cos(χ))` and the result is the `2 × N` matrix whose
rows are `x = r * cos(ϕ)` and `y = r * sin(ϕ)`.

Throws an `AssertionError` if the row count is not 2 or 4, or if the 2-row form is
used without a length-3 `model_params`.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        # Check for `nothing` first so a missing argument reports this message rather
        # than a MethodError from `length(nothing)`.
        @assert model_params !== nothing && length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    # Stack as rows: row 1 is x(t), row 2 is y(t).
    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled signal `v` with sample
spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`; the first and last
points use three-point one-sided stencils. The result has the same length as `v`, and
`v` needs at least three samples.
"""
function d_dt(v::AbstractVector, dt)
    # Forward one-sided stencil at the left boundary.
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Central differences for every interior sample.
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    # Backward one-sided stencil at the right boundary.
    tail = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(head, interior, tail) / dt
end
d_dt (generic function with 1 method)

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled signal `v` with sample
spacing `dt`.

Interior points use the central stencil `v[i-1] - 2v[i] + v[i+1]`; the first and last
points use four-point one-sided stencils. The result has the same length as `v`, and
`v` needs at least four samples.
"""
function d2_dt2(v::AbstractVector, dt)
    # Forward one-sided stencil at the left boundary.
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # Central second differences for every interior sample.
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    # Backward one-sided stencil at the right boundary.
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, interior, tail) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along an orbit.

`orbit` is indexed as `orbit[1, :]` (x) and `orbit[2, :]` (y). `component` selects the
entry: `(1, 1)` gives `x² - (x² + y²) / 3`, `(2, 2)` gives `y² - (x² + y²) / 3`, and any
other value gives the off-diagonal `x * y`. The selected series is multiplied by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    xx = xs .^ 2
    yy = ys .^ 2
    xy = xs .* ys
    tr = xx .+ yy

    # Diagonal entries have one third of the trace removed; everything else falls
    # through to the off-diagonal product.
    if component[1] == 1 && component[2] == 1
        return mass .* (xx .- tr ./ 3)
    elseif component[1] == 2 && component[2] == 2
        return mass .* (yy .- tr ./ 3)
    end
    return mass .* xy
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the moment-tensor entry selected by
`component`, computed from `orbit` sampled with spacing `dt` and weighted by `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the `(1, 1)`, `(1, 2)` and `(2, 2)` quadrupole components for `orbit` sampled
with spacing `dt`, returning them in the order `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    # Shared call shape for all three components.
    entry(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return entry((1, 1)), entry((1, 2)), entry((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the two strain polarizations for a single body following `orbit`, sampled
with spacing `dt`.

The plus polarization is built from `h11 - h22` and the cross polarization from
`2 * h12`. Both are scaled by `sqrt(π / 5)` evaluated in the number type `T` of `dt`,
and the cross polarization is returned with a negative sign. Returns `(h₊, hₓ)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # NOTE(review): the prefactor is sqrt(π / 5). The source read `=(T(π) / 5)`, which
    # looks like a `√` lost in extraction; confirm against the reference implementation.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components of two bodies: body 1 follows `orbit1` with `mass1` and
body 2 follows `orbit2` with `mass2`. Returns `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body₁ = h_22_quadrupole(dt, orbit1, mass1)
    body₂ = h_22_quadrupole(dt, orbit2, mass2)
    # Each tuple is ordered (h11, h12, h22); add them entry by entry.
    return body₁[1] + body₂[1], body₁[2] + body₂[2], body₁[3] + body₂[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the two strain polarizations from the orbits of two bodies (`orbit1` with
`mass1`, `orbit2` with `mass2`), sampled with spacing `dt`.

The masses must sum to one (checked to within `1e-12`), otherwise an `AssertionError`
is thrown. The plus polarization is built from `h11 - h22` and the cross polarization
from `2 * h12`. Both are scaled by `sqrt(π / 5)` evaluated in the number type `T` of
`dt`, and the cross polarization is returned with a negative sign. Returns `(h₊, hₓ)`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # NOTE(review): the prefactor is sqrt(π / 5). The source read `=(T(π) / 5)`, which
    # looks like a `√` lost in extraction; confirm against the reference implementation.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution `soln` into a waveform, i.e. the pair of strain series returned
by the one-body or two-body strain routine.

`soln` and `model_params` are forwarded to `soln2orbit`; `dt` is the spacing between
the columns of `soln`. `mass_ratio` must lie in `[0, 1]`, otherwise an
`AssertionError` is thrown:

  - `mass_ratio == 0` uses the orbit directly (one-body strain);
  - `mass_ratio > 0` splits the orbit between two bodies with masses
    `m₁ = mass_ratio / (1 + mass_ratio)` and `m₂ = 1 / (1 + mass_ratio)`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Masses are normalised so that m₁ + m₂ == 1, as the two-body strain requires.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

RelativisticOrbitModel defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background. It uses

$u[1] = \chi$, $u[2] = \phi$

where, p, M, and e are constants

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the reference orbit ODE in the out-of-place form `f(u, params, t)`.

`u` holds `(χ, ϕ)` and the parameter collection is unpacked as `(p, M, e)`. Neither `ϕ`
nor `t` enters the expressions. Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)

    # Factor shared by both rates.
    shared_numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    shared_root = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared_numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * shared_root)
    dϕ = shared_numer / (M * (p^(3 / 2)) * shared_root)

    return [dχ, dϕ]
end

# Problem setup shared by the reference simulation and the neural ODE cells below.
mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions (χ, ϕ)
# Number of saved time points.
datasize = 250
# NOTE(review): `tspan` uses Float32 literals while `u0` is Float64, so `tsteps` and
# `dt_data` are Float32 as well — confirm that this mixed precision is intended.
tspan = (0.0f0, 6.0f4)   # time span for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
# Spacing between saved samples; passed to `compute_waveform` as its `dt`.
dt_data = tsteps[2] - tsteps[1]
# Fixed step handed to the non-adaptive RK4 solves.
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the reference model with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the waveform tuple as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line with a marker at every saved sample.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Line and markers share a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the orbital variable χ, i.e. `u[1]`) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

It is generally not recommended to use globals, but in case you do use them, make sure to mark them as const.

We will deviate from the standard Neural Network initialization and use WeightInitializers.jl:

julia
# Network: an element-wise cosine applied to the input, followed by three Dense layers
# (1 => 32 and 32 => 32 with cos activations, then a 32 => 2 layer with no activation
# argument). Every Dense weight matrix is initialised with a truncated normal, std = 1e-4.
const nn = Chain(Base.Fix1(broadcast, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4)))
# `ps` holds the parameters of `nn` and `st` its layer states.
# NOTE(review): `Xoshiro()` is built without an explicit seed, so the initial
# parameters presumably differ between runs — pass a seed if reproducibility matters.
ps, st = Lux.setup(Xoshiro(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.00011609513; 0.0001824329; -1.9164814f-5; 0.00012400471; 9.6986194f-5; -2.3693841f-5; 5.679155f-5; 0.00015512874; 0.00018779625; 2.325075f-5; 5.8592865f-5; 2.7642405f-5; 4.1091324f-5; -6.140931f-5; -4.070845f-5; 9.16991f-6; -3.4629077f-6; 0.0001312868; -0.00015697468; -0.0001199511; 0.0001657595; 6.0088074f-5; -4.3416356f-5; -5.1955973f-5; -0.00014137432; -0.00024461583; 0.00015793058; -0.0002437934; 5.0474602f-5; 5.15484f-6; -0.00016157793; 0.00015413856;;], bias = Float32[0.72460866, 0.67704713, -0.7095343, 0.8524022, -0.06946266, 0.39859533, -0.9956064, -0.13793015, -0.11989486, 0.090450644, -0.39274597, 0.015400171, -0.137568, -0.8628316, -0.71100426, -0.27089024, -0.87879574, 0.19814658, -0.49775326, -0.29768002, 0.11228943, 0.9830328, -0.3844211, -0.552438, -0.7404752, -0.48175097, -0.92050135, -0.22908318, 0.33877707, -0.4599644, -0.63494956, -0.5810462]), layer_3 = (weight = Float32[-0.00010953119 6.604865f-5 0.00012730187 -0.00016428548 0.0001656609 5.4511216f-5 8.3122446f-5 -1.072206f-5 4.776551f-5 0.00021244102 -2.6193715f-5 0.00021454338 -4.497986f-5 -2.1200576f-5 0.00010622252 5.7849662f-5 3.994515f-5 0.00012919305 1.5488391f-5 -5.9900217f-6 2.865572f-5 -6.381755f-5 -6.890271f-5 -0.0001354001 -4.5772886f-5 -4.300227f-5 -4.122723f-5 -0.0001411418 -3.1270383f-6 3.0657648f-5 9.286812f-5 -5.5872588f-5; 0.00013250059 0.00017142249 -0.00019495844 0.00014710029 0.00021519061 6.7949663f-6 1.4987695f-6 -1.9901066f-5 3.940826f-5 5.656074f-5 -2.1348802f-5 -2.934664f-6 -0.00017032432 5.5940258f-5 -0.00019716416 2.643972f-5 -8.299655f-5 -2.6755559f-5 -6.5394f-5 -0.0001185755 -8.834825f-6 -0.0001254275 8.6588814f-5 1.6410411f-5 0.00015929455 0.00013939741 0.00010152984 -5.772553f-5 5.115179f-5 0.0001723157 8.563998f-5 0.00014392137; -3.445228f-5 4.4023724f-5 8.756983f-5 -0.00015210602 0.00013933162 -0.00016376816 0.00013247388 9.917079f-5 0.000101071724 0.00013512567 -1.0681356f-5 -0.00023419429 -0.00018826692 
0.000117103125 6.432762f-5 -7.566561f-5 -0.00018575427 -0.0001177231 -8.505805f-5 -5.188041f-5 -0.00016814136 -9.943991f-5 -3.6785113f-5 -0.00013475823 4.191113f-5 -2.4915005f-5 7.415716f-5 6.83636f-5 -8.7083215f-5 3.0799114f-5 0.00010945062 -0.00015993032; 1.4886113f-5 3.666478f-5 -4.432485f-5 -0.00022964408 -7.671727f-5 -0.00018952548 -3.4969114f-5 -0.00013288781 3.8992523f-5 -6.1254694f-5 6.119209f-5 0.00012507722 9.5749085f-5 3.146007f-5 0.000121751895 -0.00019719027 4.7077097f-5 8.200977f-5 -0.00011035831 -0.00012634505 -8.630238f-5 -0.0001445531 -1.0020011f-5 0.000121665726 0.00013672574 -0.00013766905 0.00012644194 4.0910905f-5 -0.00013371129 -5.3203512f-5 2.3079798f-5 8.442474f-7; -0.00012701646 -2.2158163f-5 3.258297f-5 6.5792505f-5 6.0463157f-5 -5.2793548f-5 -3.5685437f-6 -3.2699267f-5 2.690033f-6 1.2751852f-5 6.214043f-5 -2.4238829f-5 -1.6917393f-5 -9.4296185f-5 -3.5015828f-6 1.0477847f-5 -5.990135f-5 0.00012924394 -6.2577885f-5 0.00026527487 -6.718552f-5 0.00012970601 -0.00022648319 7.502084f-5 3.8856047f-5 -2.5322806f-6 0.0001110194 2.131585f-5 -9.739144f-5 6.397485f-5 -0.00016924448 -3.584852f-5; -9.24227f-5 0.00012450093 4.458236f-5 -2.1957225f-5 -1.1195151f-5 0.00014478146 2.0822372f-5 -2.7267184f-5 -9.872042f-5 0.00013918195 2.1270391f-6 -6.629979f-5 2.8964003f-5 -4.214982f-5 -0.0001125333 -1.7495533f-5 4.2083742f-5 -0.00028511146 -0.00011499763 -7.670543f-5 -7.41145f-5 -0.00012324326 0.00011843017 -0.00011633075 4.5033183f-5 5.408f-5 6.798807f-5 0.00014908923 -1.39140675f-5 -0.00016824421 8.368162f-5 -0.00025270792; 1.6376205f-5 2.5954327f-5 5.8985235f-5 -0.00011759817 1.6095331f-5 -5.4303055f-5 1.415962f-5 -4.381689f-5 -6.7565857f-6 0.00010821041 -7.178856f-5 -9.949935f-5 -0.000107567575 3.868913f-5 -5.2810447f-5 -3.6729277f-6 -1.349121f-5 2.9047107f-5 -0.00022519694 -2.5696221f-5 -0.00016734848 5.1005023f-5 -9.603722f-5 0.00014674429 2.034533f-5 0.00024419712 -0.00010152968 0.00016341258 5.252176f-5 3.3702232f-5 3.125563f-5 -8.55921f-5; 
-0.00010150012 3.496425f-6 1.8230228f-5 -6.5011656f-5 5.8857127f-5 -8.487991f-5 -5.888421f-5 4.125531f-5 -0.00015342954 -4.483346f-5 -7.096589f-5 -6.073797f-6 -0.000100532015 -0.0001065593 -6.158965f-5 8.832828f-5 0.00027364964 -0.00012950992 -6.76212f-5 8.2323546f-5 2.2172066f-5 -3.843092f-5 -8.2011975f-6 -9.7354576f-5 -2.7876607f-5 -6.47817f-6 -6.0623883f-5 -4.6261423f-5 8.73596f-5 -6.618558f-5 -0.00021342246 -6.2400046f-5; 1.8712699f-5 -7.440827f-5 0.00029201567 -2.8630662f-5 6.932943f-5 0.00014232437 -0.00019182084 -4.6148518f-5 -5.6717854f-5 -2.718374f-5 -0.00010045401 -1.8461547f-5 -1.2250802f-5 7.86011f-5 -0.00017421799 -2.8133452f-5 -1.28667725f-5 -7.17822f-5 7.6540324f-5 6.5196123f-6 -4.5383476f-5 0.00010798455 -8.371623f-5 -0.0001509397 8.582271f-5 -9.7777745f-5 0.00011555781 -0.00014799653 9.7068296f-5 -0.00010697705 1.5245351f-5 -0.00017884625; -0.00017383236 -5.974505f-5 -1.2803262f-5 -5.5111228f-5 0.00013899137 0.00021640085 -0.00014166758 -8.3834966f-5 -4.5186345f-5 -2.6428517f-5 3.522038f-5 6.798277f-6 -8.228107f-5 2.0961885f-5 8.631225f-5 8.758528f-6 -5.1753057f-5 -2.1218024f-5 0.000117139105 2.954444f-5 -9.44247f-6 0.00012968284 -0.00016110539 -5.2220108f-5 0.0001691586 0.0001364045 -4.7265272f-5 1.951748f-5 -6.046291f-5 0.00011957896 3.4745728f-5 -6.3546344f-5; 8.068678f-5 0.00017562397 -0.00023190457 0.00014578977 4.5682158f-5 -0.00022750275 -6.916404f-5 2.3110917f-5 -7.977746f-5 -1.507516f-5 0.000112144466 -0.00012073535 -9.054144f-5 2.2484921f-5 6.208783f-7 -0.00012633175 -6.399908f-5 0.00013822866 -2.6126307f-5 -2.9908922f-5 5.9895312f-5 9.790055f-5 -6.527341f-5 6.922738f-5 -0.00013481406 -2.4315343f-6 0.000108373664 -0.00012139608 3.2071005f-6 -7.396096f-5 -4.5134315f-5 -3.8611754f-5; 3.3937362f-5 0.00015713001 0.00012165582 -0.00021106937 6.312149f-5 -9.4041505f-5 0.00012060876 -7.327914f-5 7.751909f-6 -9.998678f-5 0.00015560427 4.9430448f-5 0.000100638434 0.00013770327 -1.4536684f-5 0.00010045765 7.186872f-5 -0.00017839344 -4.27915f-5 
8.73392f-5 6.633308f-5 -0.000113061804 -0.0001246236 -9.7663535f-5 -2.1984608f-5 6.0756836f-5 4.4831395f-5 8.123112f-6 -5.543459f-5 -2.251497f-5 -0.00014343557 -7.642344f-5; -0.00014609213 -8.550442f-5 -0.00010906914 0.00020463114 0.00021621711 -7.303436f-5 2.0264812f-5 -5.1462634f-5 -4.172627f-5 4.578152f-5 -5.498086f-5 0.00015476767 -3.0006871f-5 0.00016163009 3.3201137f-5 -6.088503f-5 3.0233266f-5 -5.5215078f-5 0.0002617644 3.3377957f-5 5.7611305f-5 -5.4666772f-5 -6.88598f-5 -5.55902f-5 0.00011739423 0.0002382795 -0.00015863964 4.2916334f-5 -9.4259005f-5 -8.032588f-5 4.6950427f-5 5.9171736f-5; -5.143713f-5 -7.920148f-5 -9.399648f-5 7.072797f-5 6.6397886f-5 1.8357485f-5 -9.203386f-5 -0.000108794564 0.00012558188 8.4609586f-5 -0.00014830206 -4.061901f-5 -3.5207264f-5 3.3960583f-5 -0.00010800746 -0.00020282886 2.4710394f-7 -4.3176584f-5 -0.0001065139 2.3620825f-5 -2.876161f-6 -0.00015371181 -8.66312f-5 -3.6186517f-5 -3.560275f-5 -1.8860688f-5 4.8293285f-5 -4.8146157f-5 0.00011891397 5.573308f-5 3.8735478f-5 -3.1584626f-5; -2.2604792f-5 -9.44439f-5 0.00015723273 -1.9918043f-5 7.74656f-5 3.4090157f-5 -0.00013546941 -0.00015931667 -0.00021294232 -6.810397f-5 3.4354158f-5 -3.6979483f-5 0.00010238097 0.00018340524 0.00012042145 -0.00017946739 8.432318f-5 0.00020150449 -2.035204f-5 0.00016522832 9.0808906f-5 -2.549776f-5 -0.00017544936 0.00013107437 -7.7941964f-5 -0.00011260641 -0.000102129336 -2.5618692f-5 -7.617219f-5 -4.938619f-5 -0.000114368435 4.2881013f-5; 5.1070583f-5 -8.180704f-5 -9.30507f-5 -1.8658679f-5 7.146141f-5 0.000118009404 8.6850494f-5 9.3688286f-5 5.324937f-5 -4.097892f-5 -2.447087f-5 -2.4841902f-5 -4.681098f-5 0.00018394584 -8.5698375f-6 -3.146314f-5 0.00014627689 -9.377641f-6 6.887984f-5 -5.896103f-5 7.392498f-5 5.791824f-5 -4.5358003f-5 0.00010579132 9.1505455f-5 -6.3640015f-5 0.00011414593 2.07737f-5 -5.2916665f-5 -7.149039f-6 -5.5375727f-5 -8.087161f-5; -4.175087f-5 -9.293265f-5 0.0002340034 -2.8112343f-5 -0.00015360076 6.1110404f-5 -0.00015173326 
-7.901318f-5 -0.000110860536 0.00011631595 -6.7414694f-5 -0.000118982956 6.319096f-5 -1.1270704f-5 -3.6331832f-5 -6.790778f-5 -6.8953676f-5 -0.00027772583 -8.059383f-5 0.00010588332 0.000118110955 0.00011533292 0.00010384234 -4.8750913f-5 -0.00016394061 5.7716086f-5 0.00013010275 6.719411f-5 -4.815047f-5 0.00012288871 -3.9511746f-5 0.00013423865; 4.7511996f-5 -7.2555005f-5 -0.0001285313 -7.943679f-6 9.8890625f-5 -1.8899289f-5 -6.926898f-5 1.1490224f-6 9.685734f-5 8.413714f-6 2.392046f-5 -2.161496f-5 4.35911f-5 1.9015934f-5 -2.1413785f-5 -6.9865266f-5 -0.00015239863 2.9802195f-5 3.4152094f-5 2.2866418f-5 4.034838f-5 2.8668439f-5 0.0001812543 6.6207074f-5 5.649964f-5 -1.2737948f-5 8.680912f-5 4.2196305f-5 9.1414026f-5 0.00011624009 3.5815845f-7 -0.00012488346; -0.00010706039 -6.441293f-5 -6.967299f-5 -4.897599f-5 8.150134f-8 -2.4615343f-5 -0.000108635526 -8.5460175f-5 1.8646819f-5 -1.7390481f-5 3.4031846f-6 -6.790847f-5 7.817524f-5 -1.3465097f-5 -6.435852f-5 -2.438516f-5 8.152233f-5 -8.602482f-5 3.6880538f-5 -4.5904135f-6 -3.1499705f-5 -1.8185252f-5 -0.00017559945 -3.6940888f-5 -2.1071979f-5 2.4607614f-5 3.8020415f-5 7.762497f-5 5.9322527f-5 -0.0001488028 -0.00017587368 9.424423f-5; -4.6827136f-5 -1.1468838f-5 1.164874f-5 0.0001158402 7.883953f-6 -5.800085f-5 -0.00015876823 5.0458388f-5 -0.000181207 1.5050576f-5 5.313777f-5 -0.00022119493 6.52978f-6 -0.00020138886 -1.452234f-5 8.318006f-5 -0.000121831596 5.1155133f-5 -7.028883f-6 -2.2948107f-5 -7.644979f-5 8.799134f-5 0.00014802 4.8733065f-5 -2.4215442f-5 -9.527473f-5 9.174768f-5 -5.138024f-6 -0.00013282531 2.6118914f-5 -8.209455f-5 -7.768508f-5; 3.490864f-6 0.00022217075 5.8998332f-5 -0.00010397887 1.3332827f-5 -1.1499459f-5 3.227698f-5 5.472335f-5 4.1158473f-5 -1.5073618f-5 0.00010191958 -0.00017565921 0.0001333506 0.00017612177 -0.00012053103 -4.9752274f-5 -0.00011244177 4.072293f-6 0.000105150735 0.000107615895 7.81312f-5 3.191294f-5 0.00018457408 -0.00013643166 1.4855038f-5 0.00019894006 -1.7730366f-5 
8.651031f-5 1.873006f-5 0.00013409156 -6.0886305f-6 7.6423195f-5; 5.0080824f-5 1.2904572f-5 0.000100075755 -5.158116f-6 8.031713f-5 9.045945f-5 0.00015096953 -0.0002003874 -2.1695245f-5 -9.87355f-5 2.5714673f-5 0.00020369283 -7.464664f-5 -8.577439f-5 4.0858067f-5 -0.00010886756 0.000119647186 -3.4766348f-5 -0.00013316511 1.4343697f-5 -4.242266f-5 0.0001529448 -0.00017318393 8.208438f-6 -1.1641294f-5 -3.135771f-5 0.00011026946 0.00014742526 8.5110245f-5 -1.0362523f-5 -0.00022382458 -5.105686f-5; -0.00018035683 -2.7712214f-5 -1.090769f-5 0.00011432961 0.0001452086 -3.0128467f-5 -6.851188f-5 0.000120348006 5.1716233f-5 -5.8855876f-5 0.0001291847 0.00024216052 3.5891793f-5 -0.00010693854 1.5688943f-6 -9.7378994f-5 -9.249981f-5 2.2374821f-5 1.5046087f-5 0.00022592091 0.00010478912 -0.00015276321 9.012196f-5 -2.4467421f-5 -0.00013231419 -0.0001302618 0.00011204244 -3.538838f-6 -8.951787f-5 4.8070626f-5 2.7315384f-6 8.0200494f-5; 8.5288586f-5 -3.440117f-5 -5.5557248f-5 8.1726546f-5 3.4609966f-5 -1.7090104f-5 -1.5557448f-6 -2.982441f-5 5.7987858f-5 -0.0001479686 2.2409291f-5 8.994695f-6 -2.5444793f-5 4.095882f-5 -7.6245735f-5 -0.00011693255 -7.123567f-5 4.5636192f-5 5.527214f-6 8.1080114f-5 0.00024678878 0.00019591497 3.7855716f-5 -1.10810315f-5 -6.089936f-5 -2.1687183f-5 -6.195716f-5 0.00012107089 5.9502658f-5 9.347119f-5 -3.0394296f-5 -7.1534654f-5; -4.562505f-6 4.783416f-5 0.00013670066 -1.3297493f-5 6.8803834f-5 1.5581074f-5 0.00011613401 -8.056028f-5 -2.222458f-5 -8.684887f-5 0.00021297844 0.000177127 -0.00011377454 0.000108276705 9.6647454f-5 -9.606093f-6 -4.2647807f-5 8.383242f-5 -9.4578936f-5 -3.7995153f-5 0.00015998527 5.63069f-6 9.380092f-5 -4.826478f-5 0.00010404007 -9.5258205f-5 0.000131682 -5.0254508f-5 -7.9704514f-5 0.00018303504 0.00010047717 -8.443969f-7; 5.3144002f-5 -2.1409609f-5 0.000110268076 -3.1317068f-5 0.00014925812 -4.359057f-5 5.238928f-5 2.8074779f-5 -0.00013524368 1.879043f-5 -0.00017701392 -0.00012368085 -0.00014045897 -0.000119126205 
-0.00019324684 3.9683815f-5 -2.0155127f-5 -7.968871f-5 0.00023778432 0.00013058062 7.060011f-5 8.1908576f-5 -0.00020509615 -8.2936866f-5 -6.770603f-6 -9.018112f-5 7.086061f-5 -4.115589f-5 2.0886306f-5 5.072067f-5 1.45232325f-5 -0.00012433798; -1.4101492f-5 -6.338118f-5 6.8959846f-5 -3.1885123f-5 8.481837f-5 -7.778899f-5 1.5720843f-5 8.978259f-5 0.00015647848 0.00011613626 -9.673388f-5 0.00012241119 3.0302948f-5 -1.934997f-5 5.7149475f-5 9.455143f-5 -1.0076608f-5 0.00013894406 -0.00015163014 -6.975489f-5 -0.0001454385 1.712966f-5 0.00012774052 0.00015430334 -0.00019867894 2.5786552f-5 0.00019528759 -7.830418f-5 4.0972092f-5 -0.00011121598 0.00012568707 4.9551527f-5; -0.00010344463 -7.184371f-5 -2.2790937f-5 -0.00013546678 6.3312264f-6 8.0119695f-5 -1.0072208f-5 -5.8009904f-5 -8.744825f-5 0.00010243698 -0.00024118231 0.00015697314 1.8419998f-5 -0.00018567388 -4.6180092f-5 1.5837469f-6 6.487918f-5 -7.230017f-5 -9.400511f-5 0.00017190231 0.00018656078 2.358007f-5 -3.529334f-5 0.00019954021 -0.0002287499 -8.693445f-5 -2.3161836f-5 0.00019836146 -1.5394467f-5 -7.7795434f-5 4.8569917f-5 3.6448877f-5; -0.0001249033 0.00014808126 3.1817744f-5 0.00010703517 -1.7187276f-5 -7.647966f-5 -0.00015130066 1.3193134f-5 -5.8106132f-5 4.2845455f-5 -8.551374f-5 -4.163147f-5 1.9591036f-5 1.568495f-5 5.4854787f-5 4.6065717f-5 9.056142f-5 1.8365504f-6 8.68557f-5 -3.353886f-5 9.6500735f-6 7.3499614f-5 4.4601402f-5 -0.000106108426 0.00011237036 0.00011681157 0.00015264702 5.234341f-5 -5.20412f-6 -4.22729f-5 -0.00012911715 2.647209f-5; -7.938373f-6 4.2673844f-5 -2.6617585f-5 1.656783f-5 7.092108f-5 6.202349f-5 -2.793506f-5 8.376832f-5 -8.1233404f-5 2.1210646f-5 1.7968765f-5 2.0627192f-5 -5.363186f-5 7.207435f-5 1.4533075f-5 -0.00014739267 -1.622916f-5 -5.8581786f-6 -0.00011486998 0.00012364432 5.197422f-5 -6.0508173f-5 2.633796f-5 0.00010490321 0.0001606455 2.0535725f-5 8.924283f-5 5.522921f-5 0.00020470467 5.3036092f-5 -0.00011954976 -0.00010241606; 0.0001772364 -2.4458275f-5 -6.494705f-5 
-0.00018576138 -2.0019783f-5 8.96027f-6 -3.2290303f-5 0.00024863076 6.577228f-6 6.219537f-5 0.00016193472 -3.0265424f-5 2.0175483f-5 3.672254f-5 0.00019269387 1.4662384f-5 0.00011587325 0.00015950226 2.7437158f-5 4.7399426f-5 2.7846133f-6 -5.6620698f-5 -2.9625403f-6 9.826812f-5 -0.00012319097 -2.7650585f-5 9.994065f-6 0.00014062108 6.1660096f-5 -2.3807688f-5 5.3463817f-5 -8.0567545f-5; -1.14582035f-5 -5.5711986f-5 9.586469f-5 -1.977659f-5 -5.97579f-5 -0.00013538754 -0.00011371761 -2.0802312f-5 -9.9069046f-5 -0.00016055588 7.007411f-5 -7.7660385f-5 1.8760686f-5 4.1213527f-5 -4.022227f-5 0.00011072605 -0.00012354745 8.556358f-5 0.000123402 -2.8348334f-6 8.509084f-5 -5.2290034f-5 -4.3382174f-6 0.00010785373 1.2645864f-5 2.1994296f-5 -1.4214328f-5 -6.7487184f-5 -0.0001204553 0.00014635186 -3.4790253f-5 9.4295436f-5], bias = Float32[0.08470798, 0.086456865, -0.11841113, -0.14723071, -0.09460315, -0.06184378, 0.13951914, 0.12184121, 0.13900806, 0.10256484, 0.091202304, -0.14918911, 0.04318183, -0.0732924, 0.1212918, -0.07931831, -0.079690754, -0.09501335, -0.080945976, -0.15352905, 0.013736311, -0.16910446, 0.06395997, -0.021855332, -0.050054863, 0.038203716, 0.094521575, 0.14302266, -0.023378309, -0.09521308, 0.09814939, 0.045443974]), layer_4 = (weight = Float32[-0.0002391026 -4.4073084f-5 -4.8281745f-5 0.00012762523 -0.00018221706 -9.343411f-6 0.0001720693 -5.1509847f-5 -1.9166316f-6 -3.2365213f-5 2.6117988f-5 -1.836218f-5 1.7979635f-5 -6.0436996f-5 -0.00014622996 -1.0849271f-5 -3.001778f-5 -7.928423f-5 0.00013554374 0.00011687534 1.0570678f-5 -1.872465f-5 -8.609983f-5 -0.000104758285 0.00028195392 -6.889872f-5 6.0980306f-6 5.1694522f-5 8.372602f-6 0.00017695039 9.337201f-5 -5.748306f-5; 0.00015297905 -0.00010769865 4.1052795f-6 -0.00014262322 -0.00015561024 0.00016070156 -3.608102f-5 -2.5221105f-5 -6.713333f-5 -3.76401f-5 0.00018570488 -0.00012350453 0.00015100559 5.4465716f-5 -1.3783012f-5 4.644399f-5 -6.506983f-5 9.527861f-5 0.0002962584 -1.9288274f-5 5.9503225f-5 
-8.325627f-5 8.625081f-5 -0.00010482496 -6.0544044f-5 -3.2250806f-5 -0.00015356197 7.6312586f-5 -0.00018434595 -0.00011259126 0.000113505725 -0.00010150186], bias = Float32[0.12655267, -0.17264202])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we need Float64

julia
# Convert the parameters returned by `Lux.setup` into a Float64 ComponentArray; this is
# the initial parameter vector used for simulation and optimisation.
const params = ComponentArray{Float64}(ps)

# Wrap `nn` with its state `st` so it can be called as `nn_model(x, parameters)`.
# The parameter slot is `nothing` here because parameters are supplied at call time.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(broadcast), typeof(cos)}(broadcast, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs which describes the motion of a point-like particle with Newtonian physics. It uses

$u[1] = \chi$, $u[2] = \phi$

where, p, M, and e are constants

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE in the out-of-place form `f(u, params, t)`.

`u` holds `(χ, ϕ)`. Both rates share the prefactor `(1 + e * cos(χ))^2 / (M * p^(3/2))`,
with `p`, `M`, `e` read from the global `ode_model_params`. Each rate is multiplied by
one plus the matching output of `nn_model`, evaluated at `χ` with parameters
`nn_params`. `t` is not used. Returns the vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely
    # ignore it; in general, `st` should be used to store the state of the network.
    correction = 1 .+ nn_model([first(u)], nn_params)

    # Factor shared by both rates.
    prefactor = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [prefactor * correction[1], prefactor * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings and save points as the reference simulation.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the waveform tuple, as for the target `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained prediction on one axis.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Target data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # Untrained neural ODE prediction.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # One legend entry per line/marker pair, placed at the left-bottom corner.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean-squared-error objective used to compare waveforms.
const mseloss = MSELoss()

"""
    loss(θ)

Evaluate the training objective for the neural-network parameters `θ`.

Solves `prob_nn` with `θ` using fixed-step RK4, converts the solution into a waveform,
and returns `(mse, pred_waveform)`, where `mse` is the mean-squared error between the
global target `waveform` and the predicted waveform.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0=u0, p=θ, saveat=tsteps, dt=dt, adaptive=false)
    trajectory = Array(sol)
    pred_waveform = first(
        compute_waveform(dt_data, trajectory, mass_ratio, ode_model_params))
    return mseloss(waveform, pred_waveform), pred_waveform
end
loss (generic function with 1 method)

Warmup the loss function

julia
# Evaluate the loss once with the initial parameters before training starts; returns
# the (loss value, predicted waveform) tuple.
loss(params)
(0.001718746522903629, [-0.021281534073881572, -0.02067247907440869, -0.02006342407493576, -0.01902857079673745, -0.017536991190539582, -0.015542510383375108, -0.012982006101712316, -0.009769558987395979, -0.005795047419457685, -0.0009177756313545156, 0.005031118106177573, 0.012230946026558552, 0.020814381578679375, 0.03066167850122043, 0.0408772673433187, 0.04856111183952473, 0.04736177172649214, 0.030425578795192518, 0.0006542124833796516, -0.0282973861370566, -0.04389177535932985, -0.04530738530432361, -0.03915606972419651, -0.030631125133237593, -0.022105401016321774, -0.014401263962028153, -0.007708053707624508, -0.0019882113715785103, 0.002863653010017779, 0.006958215774508699, 0.010395502865556749, 0.013256871205965762, 0.015606740929220828, 0.017493685356088393, 0.01895277518188356, 0.020005540919795114, 0.020663151037491286, 0.020924843882097782, 0.020778360549544263, 0.0201986698260963, 0.019146148670682, 0.017562880606350268, 0.015371083796996787, 0.012461118003897884, 0.008690982119773904, 0.0038743746753949026, -0.0022262988641122672, -0.009884597700463392, -0.019347297552307276, -0.030592541682543665, -0.04261419573457452, -0.05180847489149728, -0.05052838085055291, -0.03220877624224501, -0.0019521451965597572, 0.025850922495722014, 0.040230036612992334, 0.042087961841345965, 0.03742964731238321, 0.0304783541872911, 0.0232056241542455, 0.016378572722369422, 0.010237034338734043, 0.0048153269252658275, 6.974901810712093e-5, -0.004063642905496626, -0.007648640690326681, -0.010741101007937449, -0.013388022431403721, -0.01561999317345014, -0.01746622730150611, -0.0189359534725486, -0.020036578277024337, -0.020758986453740517, -0.021085948841495467, -0.020985957324662987, -0.02040934033850495, -0.01929049953301399, -0.017529578883854587, -0.015001649041981263, -0.011521741862422394, -0.006857777299463069, -0.0006867088650526323, 0.0073749919043595, 0.017720311373738442, 0.030448149030206435, 0.044382367726033316, 0.05497831434309598, 0.05327328889243977, 
0.033624957311166305, 0.00323794999215949, -0.023129168026307927, -0.03647413202793943, -0.038918383947691794, -0.03567660649812974, -0.030182370131635054, -0.024101013991554822, -0.018144479363434943, -0.012594911720089946, -0.0075318768150407145, -0.0029689086072989413, 0.0011273729235501404, 0.0047827756633678515, 0.008038634997018377, 0.01091604914238395, 0.013443678729759219, 0.015629842252823416, 0.017486154318374594, 0.01900597945990168, 0.02018035112178246, 0.020983671169653658, 0.021378935562977444, 0.021311195496985128, 0.020699343697977393, 0.019435821508123775, 0.017361054978744496, 0.014266166833427353, 0.009844161992604935, 0.003696733014436915, -0.004713946466457221, -0.015952911219305813, -0.030255223120556042, -0.046185336240106865, -0.057965977685766355, -0.055492175099363314, -0.034648171845518674, -0.004534727533625308, 0.02018358171476839, 0.03270128040293837, 0.03579979598273278, 0.03387510435532759, 0.029726547961255137, 0.024774662122668264, 0.019682460627944164, 0.01475363417058562, 0.010107652324654022, 0.005789079560094676, 0.001811071050583945, -0.0018472034838976342, -0.005192262492324004, -0.008241734447932407, -0.010999751672255113, -0.013482363111040388, -0.0156825938488354, -0.01759523056527638, -0.019201108688487695, -0.020476797196681856, -0.021376662821388972, -0.021839003737469345, -0.021778239727538055, -0.021065344025744774, -0.019523579743564373, -0.016896018802976295, -0.012826126251165759, -0.006786205662988636, 0.0019091058002922198, 0.014064330795398256, 0.030042943940964566, 0.04800556269373392, 0.06063444568200903, 0.057108451395786235, 0.0352702132407894, 0.0058798793994606505, -0.0170795025643811, -0.028962581054619584, -0.0327238801183047, -0.03200610215669037, -0.02909500029120271, -0.025212189299814013, -0.0209730768219115, -0.016686862392240757, -0.012502993490404691, -0.008498903848039282, -0.0046960693553385205, -0.00111113099026289, 0.002257588824162635, 0.0054047115286055035, 0.0083427774768219, 
0.011063380149310693, 0.013564121131369102, 0.015829022080640923, 0.017845664197721604, 0.019578441521750523, 0.020983023506614767, 0.02199357662779568, 0.022514381947017825, 0.022406732603486735, 0.02146787633157571, 0.019407512081060344, 0.01578146075407358, 0.009953349340560951, 0.0010240931360573906, -0.01206592409646113, -0.029848344965305588, -0.049792592711338006, -0.06283327083258884, -0.058072179486783004, -0.0355181952967624, -0.007314608051774, 0.013892648834255384, 0.025282084696022798, 0.029682877319942673, 0.030052587174006613, 0.028274662720990322, 0.025399502910755697, 0.02199512392571537, 0.01836885703788994, 0.014686803847964018, 0.011045566243594127, 0.007485636867580499, 0.004038833695348922, 0.0007157401371815004, -0.002468360655184809, -0.005518695740468013, -0.008421641694387425, -0.0111712335172603, -0.013748362757146592, -0.016140190442508872, -0.01830949290507676, -0.02021133705745073, -0.021776772445988538, -0.02290690648561866, -0.023451528606095614, -0.023187083031805965, -0.021780923456963662, -0.018709334617071367, -0.013192832150069948, -0.004083677839845506, 0.009998117457218458, 0.029688637031439664, 0.051457986818368294, 0.06443178036674987, 0.058352488975589485, 0.035465480531822635, 0.008858742092932087, -0.010681707816491373, -0.021684457193950648, -0.02666021460432403, -0.02800104291640796, -0.027253924184988898, -0.025324172687275083, -0.022730878523676272, -0.019772390458767587, -0.016622432128257113, -0.013389870740278908, -0.010129030390667162, -0.006882271145052196, -0.003669830423792736, -0.0005139396024993722, 0.0026419512187940905])

Now let us define a callback function to store the loss over time

julia
# Loss value recorded at every callback invocation, in call order.
const losses = Float64[]

"""
    callback(θ, l, pred_waveform)

Record the loss `l` in the global `losses` vector and print a progress line with the
running iteration count. `θ` and `pred_waveform` are accepted but not used. Always
returns `false`.
"""
function callback(θ, l, pred_waveform)
    push!(losses, l)
    iteration = length(losses)
    @printf("Training %10s Iteration: %5d %10s Loss: %.10f\n", "", iteration, "", l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It gives good results here because the Newtonian model provides a very good initial guess.

julia
# AD backend selector handed to the objective wrapper below.
adtype = Optimization.AutoZygote()

# The objective receives `(x, p)`; `p` is ignored and only `x` is forwarded to `loss`.
optf = Optimization.OptimizationFunction(adtype) do x, p
    loss(x)
end
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, capped at 1000 iterations; `callback` is handed
# to the solver.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000, callback=callback)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.00011609512876020744; 0.0001824328937800601; -1.916481414809823e-5; 0.00012400471314322203; 9.698619396658614e-5; -2.3693841285421513e-5; 5.67915485589765e-5; 0.00015512874233536422; 0.0001877962495200336; 2.325075001863297e-5; 5.8592864661477506e-5; 2.7642405257211067e-5; 4.1091323510045186e-5; -6.140930781839415e-5; -4.070845170645043e-5; 9.169910299533512e-6; -3.462907670837012e-6; 0.0001312867971137166; -0.00015697468188591301; -0.00011995110253337771; 0.00016575949848629534; 6.00880739511922e-5; -4.34163557656575e-5; -5.1955972594441846e-5; -0.00014137431571725756; -0.00024461583234369755; 0.0001579305826453492; -0.00024379340175073594; 5.047460217610933e-5; 5.154839982424164e-6; -0.00016157793288584799; 0.00015413855726365;;], bias = [0.7246086597442627, 0.6770471334457397, -0.7095342874526978, 0.8524022102355957, -0.06946265697479248, 0.39859533309936523, -0.9956064224243164, -0.13793015480041504, -0.11989486217498779, 0.09045064449310303, -0.3927459716796875, 0.015400171279907227, -0.13756799697875977, -0.8628315925598145, -0.7110042572021484, -0.2708902359008789, -0.8787957429885864, 0.19814658164978027, -0.4977532625198364, -0.2976800203323364, 0.1122894287109375, 0.9830328226089478, -0.38442111015319824, -0.5524380207061768, -0.7404751777648926, -0.4817509651184082, -0.9205013513565063, -0.22908318042755127, 0.3387770652770996, -0.459964394569397, -0.6349495649337769, -0.5810462236404419]), layer_3 = (weight = [-0.0001095311890821904 6.604864756809548e-5 0.00012730187154375017 -0.00016428547678515315 0.00016566089470870793 5.451121614896692e-5 8.312244608532637e-5 -1.0722060324042104e-5 4.776550849783234e-5 0.00021244102390483022 -2.6193714802502654e-5 0.00021454338275361806 -4.4979860831517726e-5 -2.1200576156843454e-5 0.00010622252011671662 5.784966197097674e-5 3.9945149183040485e-5 0.00012919305299874395 1.5488390999962576e-5 -5.99002169110463e-6 2.865572059818078e-5 
-6.381754792528227e-5 -6.890270742587745e-5 -0.0001354000996798277 -4.577288564178161e-5 -4.30022701038979e-5 -4.122723112232052e-5 -0.00014114180521573871 -3.1270383260562085e-6 3.0657647585030645e-5 9.286811837228015e-5 -5.587258783634752e-5; 0.0001325005869148299 0.00017142249271273613 -0.00019495844026096165 0.00014710029063280672 0.00021519060828723013 6.794966338929953e-6 1.4987695067247842e-6 -1.9901066480088048e-5 3.94082599086687e-5 5.6560740631539375e-5 -2.13488019653596e-5 -2.9346640531002777e-6 -0.0001703243178781122 5.594025788013823e-5 -0.00019716416136361659 2.6439720386406407e-5 -8.299655019072816e-5 -2.6755558792501688e-5 -6.539400055771694e-5 -0.00011857550271088257 -8.834825166559312e-6 -0.00012542749755084515 8.658881415612996e-5 1.641041126276832e-5 0.00015929454821161926 0.00013939740892965347 0.00010152984032174572 -5.7725530496099964e-5 5.115179010317661e-5 0.00017231570382136852 8.563997835153714e-5 0.00014392136654350907; -3.4452281397534534e-5 4.4023723603459075e-5 8.756983152125031e-5 -0.00015210601850412786 0.00013933161972090602 -0.0001637681561987847 0.00013247388415038586 9.91707929642871e-5 0.00010107172420248389 0.00013512566511053592 -1.0681355888664257e-5 -0.00023419428907800466 -0.0001882669166661799 0.00011710312537616119 6.432762165786698e-5 -7.566560816485435e-5 -0.00018575426656752825 -0.00011772310244850814 -8.505804726155475e-5 -5.188041177461855e-5 -0.00016814135597087443 -9.94399088085629e-5 -3.678511347970925e-5 -0.00013475822925101966 4.191113112028688e-5 -2.4915005269576795e-5 7.415715663228184e-5 6.836360262241215e-5 -8.708321547601372e-5 3.0799114028923213e-5 0.00010945062240352854 -0.0001599303213879466; 1.4886112694512121e-5 3.666478005470708e-5 -4.432485002325848e-5 -0.00022964407980907708 -7.67172678024508e-5 -0.0001895254827104509 -3.4969114494742826e-5 -0.00013288781337905675 3.899252260453068e-5 -6.125469371909276e-5 6.119209137978032e-5 0.00012507721839938313 9.574908472131938e-5 3.146006929455325e-5 
0.00012175189476693049 -0.00019719026749953628 4.7077097406145185e-5 8.200977026717737e-5 -0.00011035831266781315 -0.0001263450540136546 -8.63023815327324e-5 -0.00014455309428740293 -1.0020011359301861e-5 0.00012166572560090572 0.00013672573550138623 -0.00013766904885414988 0.00012644194066524506 4.0910905227065086e-5 -0.0001337112917099148 -5.320351192494854e-5 2.3079797756508924e-5 8.442473813374818e-7; -0.0001270164648303762 -2.215816311945673e-5 3.2582971471128985e-5 6.579250475624576e-5 6.046315684216097e-5 -5.279354809317738e-5 -3.568543661458534e-6 -3.2699266739655286e-5 2.6900329430645797e-6 1.275185240956489e-5 6.214043241925538e-5 -2.4238828700617887e-5 -1.6917392713367008e-5 -9.429618512513116e-5 -3.5015827961615287e-6 1.0477847354195546e-5 -5.990135105093941e-5 0.00012924394104629755 -6.257788481889293e-5 0.00026527486625127494 -6.718552322126925e-5 0.00012970600801054388 -0.00022648318554274738 7.502084190491587e-5 3.885604746756144e-5 -2.5322806322947145e-6 0.00011101939890068024 2.131584915332496e-5 -9.739144297782332e-5 6.39748468529433e-5 -0.000169244478456676 -3.58485194738023e-5; -9.242269879905507e-5 0.00012450093345250934 4.458236071513966e-5 -2.195722481701523e-5 -1.1195151273568626e-5 0.0001447814574930817 2.08223718800582e-5 -2.7267184123047628e-5 -9.872041846392676e-5 0.00013918195327278227 2.12703912438883e-6 -6.629979179706424e-5 2.896400292229373e-5 -4.214981890982017e-5 -0.00011253330012550578 -1.749553302943241e-5 4.208374230074696e-5 -0.0002851114550139755 -0.00011499763058964163 -7.670542981941253e-5 -7.411449769278988e-5 -0.00012324325507506728 0.00011843017273349687 -0.00011633075337158516 4.503318268689327e-5 5.407999924500473e-5 6.79880686220713e-5 0.00014908923185430467 -1.391406749462476e-5 -0.00016824420890770853 8.36816179798916e-5 -0.0002527079195715487; 1.637620516703464e-5 2.5954326702049002e-5 5.8985235227737576e-5 -0.000117598166980315 1.6095331375254318e-5 -5.430305463960394e-5 1.415962015016703e-5 -4.381688995636068e-5 
-6.7565856625151355e-6 0.00010821041360031813 -7.178856321843341e-5 -9.949935338227078e-5 -0.00010756757546914741 3.868912972393446e-5 -5.281044650473632e-5 -3.672927732623066e-6 -1.349120975646656e-5 2.904710709117353e-5 -0.00022519694175571203 -2.5696221200632863e-5 -0.00016734848031774163 5.100502312416211e-5 -9.603721991879866e-5 0.00014674429257865995 2.0345329176052473e-5 0.0002441971155349165 -0.00010152968025067821 0.00016341258015017956 5.2521758334478363e-5 3.370223203091882e-5 3.1255629437509924e-5 -8.559210255043581e-5; -0.00010150011803489178 3.4964250517077744e-6 1.823022830649279e-5 -6.501165626104921e-5 5.885712744202465e-5 -8.48799099912867e-5 -5.888420855626464e-5 4.125530904275365e-5 -0.00015342954429797828 -4.483346128836274e-5 -7.096589251887053e-5 -6.073797067074338e-6 -0.00010053201549453661 -0.00010655930236680433 -6.15896497038193e-5 8.832827734295279e-5 0.00027364963898435235 -0.00012950992095284164 -6.762120028724894e-5 8.232354593928903e-5 2.2172065655468032e-5 -3.8430920540122315e-5 -8.201197488233447e-6 -9.735457570059225e-5 -2.787660741887521e-5 -6.478169780166354e-6 -6.062388274585828e-5 -4.6261422539828345e-5 8.735960000194609e-5 -6.618558109039441e-5 -0.0002134224632754922 -6.240004586288705e-5; 1.8712698874878697e-5 -7.440827175742015e-5 0.00029201566940173507 -2.863066220015753e-5 6.932942778803408e-5 0.00014232436660677195 -0.00019182084361091256 -4.6148517867550254e-5 -5.6717854022281244e-5 -2.7183739803149365e-5 -0.00010045400995295495 -1.8461547369952314e-5 -1.2250801773916464e-5 7.860110054025427e-5 -0.00017421798838768154 -2.8133452360634692e-5 -1.286677252210211e-5 -7.1782196755521e-5 7.654032378923148e-5 6.51961227049469e-6 -4.5383476390270516e-5 0.00010798455332405865 -8.371622971026227e-5 -0.00015093969705048949 8.582270675105974e-5 -9.777774539543316e-5 0.00011555780656635761 -0.0001479965285398066 9.706829587230459e-5 -0.00010697705147322267 1.5245351278281305e-5 -0.00017884625412989408; -0.00017383236263412982 
-5.974504892947152e-5 -1.2803261597582605e-5 -5.511122799362056e-5 0.00013899136683903635 0.00021640084742102772 -0.00014166758046485484 -8.383496606256813e-5 -4.518634523265064e-5 -2.6428517230669968e-5 3.522037877701223e-5 6.7982768996444065e-6 -8.228106889873743e-5 2.096188472933136e-5 8.631224773125723e-5 8.758527656027582e-6 -5.175305705051869e-5 -2.1218023903202266e-5 0.00011713910498656332 2.954444062197581e-5 -9.442470400244929e-6 0.00012968284136150032 -0.00016110538854263723 -5.222010804573074e-5 0.00016915859305299819 0.00013640450197272003 -4.7265271859942004e-5 1.95174798136577e-5 -6.046290945960209e-5 0.00011957895912928507 3.4745728044072166e-5 -6.35463438811712e-5; 8.06867828941904e-5 0.0001756239653332159 -0.0002319045743206516 0.00014578977425117046 4.568215808831155e-5 -0.00022750275093130767 -6.916403799550608e-5 2.3110917027224787e-5 -7.977746281540021e-5 -1.5075160263222642e-5 0.0001121444656746462 -0.00012073534890078008 -9.054144175024703e-5 2.2484920918941498e-5 6.208783247529936e-7 -0.00012633175356313586 -6.399908306775615e-5 0.00013822865730617195 -2.6126306693186052e-5 -2.9908922442700714e-5 5.989531200611964e-5 9.790054900804535e-5 -6.527340883621946e-5 6.922738248249516e-5 -0.0001348140649497509 -2.431534312563599e-6 0.00010837366426130757 -0.00012139607861172408 3.207100462532253e-6 -7.396096043521538e-5 -4.513431485975161e-5 -3.861175355268642e-5; 3.393736187717877e-5 0.0001571300090290606 0.00012165582302259281 -0.00021106937492731959 6.312149344012141e-5 -9.40415047807619e-5 0.00012060876179020852 -7.327913772314787e-5 7.751908924547024e-6 -9.998677705880255e-5 0.0001556042698211968 4.94304476887919e-5 0.00010063843365060166 0.000137703274958767 -1.453668392059626e-5 0.00010045764793176204 7.186872244346887e-5 -0.000178393442183733 -4.2791500163730234e-5 8.733919821679592e-5 6.633307930314913e-5 -0.00011306180385872722 -0.00012462360609788448 -9.766353468876332e-5 -2.198460788349621e-5 6.0756836319342256e-5 4.483139491640031e-5 
8.12311191111803e-6 -5.5434589739888906e-5 -2.2514970623888075e-5 -0.00014343556540552527 -7.642344280611724e-5; -0.00014609213394578546 -8.550441998522729e-5 -0.00010906913666985929 0.00020463114196900278 0.00021621711493935436 -7.303435995709151e-5 2.0264811610104516e-5 -5.146263356436975e-5 -4.1726270865183324e-5 4.578151856549084e-5 -5.498086102306843e-5 0.00015476766566280276 -3.000687138410285e-5 0.00016163008695002645 3.320113683003001e-5 -6.088503141654655e-5 3.0233266443246976e-5 -5.52150777366478e-5 0.0002617643913254142 3.3377957151969895e-5 5.76113052375149e-5 -5.466677248477936e-5 -6.885980110382661e-5 -5.559020064538345e-5 0.00011739422916434705 0.000238279506447725 -0.0001586396392667666 4.291633376851678e-5 -9.425900498172268e-5 -8.032588084461167e-5 4.695042662206106e-5 5.917173621128313e-5; -5.143713133293204e-5 -7.920148345874622e-5 -9.399648115504533e-5 7.072796870488673e-5 6.639788625761867e-5 1.835748480516486e-5 -9.203385707223788e-5 -0.00010879456385737285 0.0001255818788195029 8.460958633804694e-5 -0.0001483020605519414 -4.0619008359499276e-5 -3.520726386341266e-5 3.396058309590444e-5 -0.0001080074580386281 -0.0002028288581641391 2.4710394086469023e-7 -4.31765838584397e-5 -0.00010651390039129183 2.362082523177378e-5 -2.8761610337824095e-6 -0.00015371180779766291 -8.663119660923257e-5 -3.6186516808811575e-5 -3.5602748539531603e-5 -1.886068821477238e-5 4.8293284635292366e-5 -4.814615749637596e-5 0.00011891397298313677 5.573307862505317e-5 3.8735477573936805e-5 -3.1584626412950456e-5; -2.260479232063517e-5 -9.444390161661431e-5 0.00015723273099865764 -1.991804310819134e-5 7.746560004306957e-5 3.4090156987076625e-5 -0.0001354694104520604 -0.00015931666712276638 -0.00021294232283253223 -6.810396735090762e-5 3.435415783314966e-5 -3.697948341141455e-5 0.00010238096729153767 0.00018340523820370436 0.00012042144953738898 -0.0001794673880795017 8.432318281847984e-5 0.00020150448835920542 -2.0352039427962154e-5 0.0001652283244766295 
9.080890595214441e-5 -2.5497760361758992e-5 -0.00017544935690239072 0.0001310743682552129 -7.794196426402777e-5 -0.00011260640894761309 -0.00010212933557340875 -2.561869223427493e-5 -7.617218943778425e-5 -4.938619167660363e-5 -0.0001143684348789975 4.2881012632278726e-5; 5.107058314024471e-5 -8.180704026017338e-5 -9.305070125265047e-5 -1.8658678527572192e-5 7.146140706026927e-5 0.00011800940410466865 8.685049397172406e-5 9.368828614242375e-5 5.3249368647811934e-5 -4.097892087884247e-5 -2.4470869902870618e-5 -2.484190190443769e-5 -4.6810979256406426e-5 0.00018394584185443819 -8.569837518734857e-6 -3.1463139748666435e-5 0.0001462768850615248 -9.377640708407853e-6 6.887983909109607e-5 -5.8961031754733995e-5 7.392498082481325e-5 5.7918241509469226e-5 -4.535800326266326e-5 0.00010579131776466966 9.150545520242304e-5 -6.36400145594962e-5 0.00011414592881919816 2.077369936159812e-5 -5.2916664571966976e-5 -7.1490389927930664e-6 -5.537572724279016e-5 -8.087160676950589e-5; -4.175087087787688e-5 -9.293264884036034e-5 0.00023400339705403894 -2.8112342988606542e-5 -0.00015360076213255525 6.111040420364588e-5 -0.00015173325664363801 -7.901318167569116e-5 -0.00011086053564213216 0.0001163159467978403 -6.741469405824319e-5 -0.00011898295633727685 6.319095700746402e-5 -1.1270703907939605e-5 -3.6331832234282047e-5 -6.790777842979878e-5 -6.895367550896481e-5 -0.0002777258341666311 -8.059383253566921e-5 0.000105883322248701 0.0001181109546450898 0.00011533292126841843 0.00010384234337834641 -4.875091326539405e-5 -0.00016394061094615608 5.7716086303116754e-5 0.00013010275142733008 6.719410885125399e-5 -4.815046850126237e-5 0.00012288871221244335 -3.951174585381523e-5 0.00013423865311779082; 4.751199594466016e-5 -7.255500531755388e-5 -0.00012853130465373397 -7.943678610899951e-6 9.889062494039536e-5 -1.8899288988905028e-5 -6.926897913217545e-5 1.1490224096633028e-6 9.685733675723895e-5 8.413713658228517e-6 2.3920460080262274e-5 -2.1614960132865235e-5 4.3591098801698536e-5 
1.901593350339681e-5 -2.1413785361801274e-5 -6.986526568653062e-5 -0.00015239862841553986 2.9802195058437064e-5 3.415209357626736e-5 2.286641756654717e-5 4.034837911603972e-5 2.866843897209037e-5 0.00018125430506188422 6.62070742691867e-5 5.649964077747427e-5 -1.2737948054564185e-5 8.680912287672982e-5 4.219630500301719e-5 9.141402551904321e-5 0.00011624008766375482 3.5815844512399053e-7 -0.00012488345964811742; -0.00010706039029173553 -6.441293226089329e-5 -6.967299123061821e-5 -4.8975991376210004e-5 8.150134078732663e-8 -2.4615343136247247e-5 -0.00010863552597584203 -8.546017488697544e-5 1.8646818716661073e-5 -1.7390480934409425e-5 3.4031845643767156e-6 -6.790846964577213e-5 7.817523874109611e-5 -1.3465097254083958e-5 -6.435852264985442e-5 -2.4385159122175537e-5 8.152233203873038e-5 -8.602481830166653e-5 3.688053766381927e-5 -4.5904134822194465e-6 -3.1499705073656514e-5 -1.8185251974500716e-5 -0.0001755994453560561 -3.694088809425011e-5 -2.1071979062980972e-5 2.460761425027158e-5 3.8020414649508893e-5 7.762497261865065e-5 5.932252679485828e-5 -0.00014880280650686473 -0.00017587367619853467 9.424422751180828e-5; -4.682713552028872e-5 -1.146883823821554e-5 1.1648739928205032e-5 0.0001158402010332793 7.883953003329225e-6 -5.800084909424186e-5 -0.00015876823454163969 5.045838770456612e-5 -0.0001812069967854768 1.5050575711939018e-5 5.313777000992559e-5 -0.0002211949322372675 6.529779966513161e-6 -0.00020138885884080082 -1.45223402796546e-5 8.318005711771548e-5 -0.00012183159560663626 5.1155133405700326e-5 -7.0288829192577396e-6 -2.294810656167101e-5 -7.64497890486382e-5 8.799134229775518e-5 0.00014802000077906996 4.8733065341366455e-5 -2.4215441953856498e-5 -9.527472866466269e-5 9.174767910735682e-5 -5.138023880135734e-6 -0.00013282531290315092 2.6118914320250042e-5 -8.209454972529784e-5 -7.768507930450141e-5; 3.490863946353784e-6 0.00022217075456865132 5.8998331951443106e-5 -0.00010397886944701895 1.3332826711121015e-5 -1.149945910583483e-5 3.22769810736645e-5 
5.472335033118725e-5 4.1158473322866485e-5 -1.5073617760208435e-5 0.00010191957699134946 -0.000175659210071899 0.00013335059338714927 0.0001761217718012631 -0.0001205310327350162 -4.9752274208003655e-5 -0.00011244176857871935 4.072292995260796e-6 0.00010515073518035933 0.0001076158951036632 7.813119736965746e-5 3.1912939448375255e-5 0.00018457407713867724 -0.00013643165584653616 1.4855037989036646e-5 0.00019894006254617125 -1.7730366380419582e-5 8.651030657347292e-5 1.873005930974614e-5 0.00013409156235866249 -6.0886304709129035e-6 7.642319542355835e-5; 5.0080823712050915e-5 1.2904572031402495e-5 0.0001000757547444664 -5.15811598233995e-6 8.031712786760181e-5 9.04594489838928e-5 0.00015096952847670764 -0.00020038739603478462 -2.169524486816954e-5 -9.873550152406096e-5 2.5714673029142432e-5 0.0002036928344750777 -7.464663940481842e-5 -8.577438711654395e-5 4.0858067222870886e-5 -0.00010886756354011595 0.00011964718578383327 -3.476634810795076e-5 -0.0001331651146756485 1.4343697330332361e-5 -4.2422660044394433e-5 0.00015294480544980615 -0.0001731839292915538 8.208437975554261e-6 -1.1641293895081617e-5 -3.135771112283692e-5 0.00011026945867342874 0.0001474252640036866 8.51102449814789e-5 -1.0362523426010739e-5 -0.0002238245797343552 -5.105686068418436e-5; -0.00018035683024208993 -2.771221443254035e-5 -1.0907690011663362e-5 0.00011432961036916822 0.00014520859986077994 -3.0128467187751085e-5 -6.851187936263159e-5 0.00012034800602123141 5.171623342903331e-5 -5.885587597731501e-5 0.00012918470019940287 0.00024216051679104567 3.5891793231712654e-5 -0.0001069385398295708 1.5688942767155822e-6 -9.737899381434545e-5 -9.249980939785019e-5 2.2374821128323674e-5 1.5046087355585769e-5 0.0002259209140902385 0.00010478912008693442 -0.00015276321209967136 9.012196096591651e-5 -2.4467421098961495e-5 -0.00013231419143266976 -0.00013026180386077613 0.00011204244219698012 -3.5388379728829022e-6 -8.951786730904132e-5 4.8070625780383125e-5 2.7315384159010137e-6 8.020049426704645e-5; 
8.528858597856015e-5 -3.44011714332737e-5 -5.555724783334881e-5 8.172654634108767e-5 3.460996595094912e-5 -1.7090103938244283e-5 -1.5557448023173492e-6 -2.982441037602257e-5 5.798785787192173e-5 -0.00014796860341448337 2.2409290977520868e-5 8.994695235742256e-6 -2.544479320931714e-5 4.095882104593329e-5 -7.624573481734842e-5 -0.00011693254782585427 -7.123567047528923e-5 4.563619222608395e-5 5.527213943423703e-6 8.108011388685554e-5 0.0002467887825332582 0.00019591496675275266 3.785571607295424e-5 -1.1081031516368967e-5 -6.089936141506769e-5 -2.1687183107133023e-5 -6.195715832291171e-5 0.00012107088696211576 5.9502657677512616e-5 9.347119339508936e-5 -3.039429611817468e-5 -7.153465412557125e-5; -4.562505182548193e-6 4.783416079590097e-5 0.00013670066255144775 -1.3297492841957137e-5 6.880383443785831e-5 1.5581073967041448e-5 0.00011613401147769764 -8.056028309511021e-5 -2.22245798795484e-5 -8.68488714331761e-5 0.00021297844068612903 0.00017712700355332345 -0.0001137745421146974 0.00010827670485014096 9.664745448390022e-5 -9.606093044567388e-6 -4.2647807276807725e-5 8.383241947740316e-5 -9.457893611397594e-5 -3.799515252467245e-5 0.00015998526941984892 5.630689884128515e-6 9.380091796629131e-5 -4.82647810713388e-5 0.00010404006752651185 -9.525820496492088e-5 0.0001316819980274886 -5.025450809625909e-5 -7.970451406436041e-5 0.0001830350374802947 0.0001004771693260409 -8.443968795290857e-7; 5.314400186762214e-5 -2.1409608962130733e-5 0.00011026807624148205 -3.131706762360409e-5 0.0001492581213824451 -4.359057129477151e-5 5.238928133621812e-5 2.8074779038433917e-5 -0.00013524368114303797 1.8790429749060422e-5 -0.0001770139206200838 -0.0001236808457178995 -0.00014045897114556283 -0.00011912620539078489 -0.00019324684399180114 3.968381497543305e-5 -2.01551265490707e-5 -7.96887106844224e-5 0.0002377843193244189 0.00013058062177151442 7.060010830173269e-5 8.190857624867931e-5 -0.00020509614842012525 -8.293686551041901e-5 -6.77060279485886e-6 -9.018112177727744e-5 
7.08606094121933e-5 -4.115589035791345e-5 2.088630571961403e-5 5.0720671424642205e-5 1.4523232493957039e-5 -0.0001243379811057821; -1.4101491615292616e-5 -6.338117964332923e-5 6.895984552102163e-5 -3.188512346241623e-5 8.481836994178593e-5 -7.778898725518957e-5 1.5720843293820508e-5 8.978258847491816e-5 0.00015647847612854093 0.00011613625974860042 -9.673387830844149e-5 0.0001224111911142245 3.030294828931801e-5 -1.9349970898474567e-5 5.7149474741891026e-5 9.455143299419433e-5 -1.0076608305098489e-5 0.00013894405856262892 -0.0001516301417723298 -6.975488940952346e-5 -0.0001454385055694729 1.712965968181379e-5 0.00012774052447639406 0.00015430334315169603 -0.00019867894297931343 2.578655221441295e-5 0.00019528759003151208 -7.830418326193467e-5 4.097209239262156e-5 -0.00011121598072350025 0.00012568707461468875 4.955152689944953e-5; -0.00010344463225919753 -7.18437077011913e-5 -2.279093678225763e-5 -0.00013546677655540407 6.331226359179709e-6 8.011969475774094e-5 -1.0072208169731312e-5 -5.800990402349271e-5 -8.74482502695173e-5 0.00010243697761325166 -0.0002411823079455644 0.0001569731393828988 1.841999801399652e-5 -0.00018567388178780675 -4.6180091885617e-5 1.5837468936297228e-6 6.487918290076777e-5 -7.230017217807472e-5 -9.400511044077575e-5 0.00017190231301356107 0.00018656077736523002 2.3580070774187334e-5 -3.529333844198845e-5 0.00019954021263401955 -0.0002287498937221244 -8.693445124663413e-5 -2.3161835997598246e-5 0.00019836146384477615 -1.5394467482110485e-5 -7.779543375363573e-5 4.856991654378362e-5 3.644887692644261e-5; -0.0001249032939085737 0.00014808126434218138 3.1817744456930086e-5 0.00010703517182264477 -1.7187276171171106e-5 -7.647965685464442e-5 -0.00015130065730772913 1.3193133781896904e-5 -5.810613220091909e-5 4.284545502741821e-5 -8.551374048693106e-5 -4.163146877544932e-5 1.959103610715829e-5 1.5684949175920337e-5 5.4854786867508665e-5 4.6065717469900846e-5 9.056142152985558e-5 1.836550381995039e-6 8.685570355737582e-5 -3.3538861316628754e-5 
9.650073479861021e-6 7.349961379077286e-5 4.4601401896215975e-5 -0.00010610842582536861 0.00011237036233069375 0.00011681157047860324 0.00015264701505657285 5.23434100614395e-5 -5.2041200433450285e-6 -4.2272899008821696e-5 -0.0001291171502089128 2.6472089302842505e-5; -7.938372618809808e-6 4.267384429113008e-5 -2.6617584808263928e-5 1.6567830243729986e-5 7.092107989592478e-5 6.202349322848022e-5 -2.7935060643358156e-5 8.376831829082221e-5 -8.123340376187116e-5 2.1210646082181484e-5 1.7968764950637706e-5 2.0627192498068325e-5 -5.363185846363194e-5 7.20743482816033e-5 1.4533075045619626e-5 -0.00014739266771357507 -1.6229159882641397e-5 -5.858178610651521e-6 -0.0001148699811892584 0.0001236443204106763 5.197422069613822e-5 -6.050817319191992e-5 2.633796066220384e-5 0.00010490320710232481 0.00016064550436567515 2.0535724615911022e-5 8.924282883526757e-5 5.522921128431335e-5 0.00020470467279665172 5.303609214024618e-5 -0.00011954976071137935 -0.00010241605923511088; 0.0001772364048520103 -2.4458275220240466e-5 -6.494705303339288e-5 -0.00018576138245407492 -2.0019782823510468e-5 8.960269951785449e-6 -3.229030335205607e-5 0.0002486307639628649 6.577227850357303e-6 6.219537317520007e-5 0.00016193471674341708 -3.0265424356912263e-5 2.0175482859485783e-5 3.6722540244227275e-5 0.0001926938712131232 1.4662384273833595e-5 0.00011587324843276292 0.00015950226224958897 2.7437157768872567e-5 4.739942596643232e-5 2.7846133434650255e-6 -5.662069816025905e-5 -2.9625402930832934e-6 9.826812311075628e-5 -0.00012319097004365176 -2.7650585252558812e-5 9.994065294449683e-6 0.00014062107948120683 6.166009552543983e-5 -2.380768819421064e-5 5.346381658455357e-5 -8.056754450080916e-5; -1.1458203516667709e-5 -5.571198562392965e-5 9.586469241185114e-5 -1.9776589397224598e-5 -5.975789827061817e-5 -0.0001353875413769856 -0.00011371760774636641 -2.0802312064915895e-5 -9.906904597301036e-5 -0.00016055587911978364 7.007410749793053e-5 -7.766038470435888e-5 1.876068563433364e-5 4.1213526856154203e-5 
-4.022226858069189e-5 0.0001107260468415916 -0.00012354744831100106 8.556358079658821e-5 0.00012340200191829354 -2.834833367160172e-6 8.509084000252187e-5 -5.229003363638185e-5 -4.338217422628077e-6 0.00010785373160615563 1.2645864444493782e-5 2.1994295821059495e-5 -1.4214328075468075e-5 -6.74871844239533e-5 -0.00012045529729221016 0.00014635185652878135 -3.479025326669216e-5 9.42954357014969e-5], bias = [0.08470798283815384, 0.08645686507225037, -0.11841113120317459, -0.14723071455955505, -0.09460315108299255, -0.06184377893805504, 0.13951914012432098, 0.12184120714664459, 0.1390080600976944, 0.10256484150886536, 0.09120230376720428, -0.14918911457061768, 0.043181829154491425, -0.07329239696264267, 0.12129180133342743, -0.07931830734014511, -0.07969075441360474, -0.09501335024833679, -0.08094597607851028, -0.15352904796600342, 0.013736311346292496, -0.1691044569015503, 0.06395997107028961, -0.02185533195734024, -0.05005486309528351, 0.03820371627807617, 0.09452157467603683, 0.14302265644073486, -0.02337830886244774, -0.09521307796239853, 0.0981493890285492, 0.04544397443532944]), layer_4 = (weight = [-0.0002391026064287871 -4.4073083699913695e-5 -4.828174496651627e-5 0.0001276252296520397 -0.00018221705977339298 -9.343410965811927e-6 0.00017206929624080658 -5.150984725332819e-5 -1.916631617859821e-6 -3.23652129736729e-5 2.6117988454643637e-5 -1.836217961681541e-5 1.7979635231313296e-5 -6.0436996136559173e-5 -0.00014622995513491333 -1.0849271347979084e-5 -3.0017779863555916e-5 -7.928423292469233e-5 0.0001355437416350469 0.00011687533697113395 1.0570677659416106e-5 -1.8724649635259993e-5 -8.609983342466876e-5 -0.00010475828457856551 0.00028195392224006355 -6.889872020110488e-5 6.098030553403078e-6 5.1694521971512586e-5 8.372601769224275e-6 0.00017695038695819676 9.337200754089281e-5 -5.74830592086073e-5; 0.0001529790461063385 -0.00010769865184556693 4.105279458599398e-6 -0.00014262321928981692 -0.00015561023610644042 0.00016070155834313482 -3.608101906138472e-5 
-2.5221104806405492e-5 -6.713333277730271e-5 -3.7640100345015526e-5 0.00018570487736724317 -0.0001235045347129926 0.00015100558812264353 5.446571594802663e-5 -1.3783012036583386e-5 4.6443990868283436e-5 -6.506982754217461e-5 9.527860675007105e-5 0.00029625839670188725 -1.9288274415885098e-5 5.950322520220652e-5 -8.325627277372405e-5 8.625080954516307e-5 -0.00010482496145414189 -6.054404366295785e-5 -3.2250805816147476e-5 -0.00015356196672655642 7.631258631590754e-5 -0.00018434594676364213 -0.00011259126040386036 0.00011350572458468378 -0.00010150185698876157], bias = [0.12655267119407654, -0.17264202237129211]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    # Single-panel figure for the recorded loss history.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # Draw the curve first, then overlay one marker per recorded loss.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig  # last expression, so the figure is the value of the block
end

Finally, let us visualize the results.

julia
# Solve the ODE again with the optimized parameters `res.u`, using fixed-step RK4 and
# saving the solution at `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(
    prob_nn, RK4(); u0=u0, p=res.u, dt=dt, saveat=tsteps, adaptive=false) |> Array

# Only the first element of `compute_waveform`'s result is kept.
waveform_nn_trained = compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Styling shared by every curve and by every marker overlay.
    curve_style = (; linewidth=2, alpha=0.75)
    marker_style = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Reference waveform data.
    l1 = lines!(ax, tsteps, waveform; curve_style...)
    s1 = scatter!(ax, tsteps, waveform; marker_style...)

    # Neural-network waveform before training.
    l2 = lines!(ax, tsteps, waveform_nn; curve_style...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_style...)

    # Neural-network waveform after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; curve_style...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_style...)

    # Each legend entry pairs a curve with its marker overlay.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig  # last expression, so the figure is the value of the block
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are both
# defined in this session and `MLDataDevices.functional` accepts the device type.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.