
Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
    OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model, which are not especially relevant from a scientific machine learning perspective.

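We need a very crude two-body path. Assume the one-body motion is a Newtonian two-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to get $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$ with $M = m_1 + m_2$ (e.g. Theoretical Mechanics of Particles and Continua, Section 4.3).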

julia
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
one2two (generic function with 1 method)
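As a quick sanity check, a minimal sketch (the sample path and masses below are made-up values, not from the tutorial) confirming that the two positions returned by one2two recover the relative separation $r_1 - r_2 = r$ and keep the centre of mass at the origin:

```julia
# Copy of one2two from above, so this check runs on its own.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Hypothetical 2×N relative path: row 1 holds x, row 2 holds y.
path = [1.0 0.5 -0.2; 0.0 0.8 0.3]
m₁, m₂ = 0.25, 0.75
r₁, r₂ = one2two(path, m₁, m₂)

sep_ok = (r₁ .- r₂) ≈ path                          # relative separation recovered
com_err = maximum(abs, m₁ .* r₁ .+ m₂ .* r₂)        # centre of mass, should be ~0
@show sep_ok com_err
```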

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
@views function soln2orbit(soln, model_params = nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        @assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
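The core of soln2orbit is the conic-section radius $r = p / (1 + e\cos\chi)$. A minimal standalone sketch of that mapping (conic_radius is a hypothetical helper, not part of the tutorial) shows that $\chi = 0$ gives the closest approach $p/(1+e)$ and $\chi = \pi$ the farthest point $p/(1-e)$:

```julia
# Hypothetical helper: the radius formula used inside soln2orbit.
conic_radius(χ, p, e) = p / (1 + e * cos(χ))

p, e = 100.0, 0.5                  # same p and e as ode_model_params further below
r_peri = conic_radius(0.0, p, e)   # χ = 0: closest approach, p / (1 + e)
r_apo = conic_radius(π, p, e)      # χ = π: farthest point, p / (1 - e)

# Position at χ = π, ϕ = 0 (the initial condition u0 used later): on the +x axis.
x, y = r_apo * cos(0.0), r_apo * sin(0.0)
@show r_peri r_apo (x, y)
```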

d_dt approximates the first derivative of a uniformly sampled vector: central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
d_dt (generic function with 1 method)

d2_dt2 approximates the second derivative in the same way: the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
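Second-order stencils differentiate quadratics exactly (up to round-off), which gives a cheap check of both helpers, endpoints included. The sketch below repeats the two definitions so it runs on its own and compares them with $f(t) = t^2$, $f'(t) = 2t$, $f''(t) = 2$:

```julia
# Repeated from above so this check is self-contained.
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end

function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

dt = 0.1
t = collect(0:dt:1)   # 11 uniformly spaced samples
v = t .^ 2            # f(t) = t², so f'(t) = 2t and f''(t) = 2

# Maximum absolute error over all samples, including both endpoints.
err1 = maximum(abs.(d_dt(v, dt) .- 2 .* t))
err2 = maximum(abs.(d2_dt2(v, dt) .- 2))
println("max |d_dt error| = ", err1, ", max |d2_dt2 error| = ", err2)
```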

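Now we define functions that compute the trace-free mass quadrupole moment tensor from the orbit and, from its second time derivative, the (2,2)-mode strain components h₊ and hₓ for one or two bodies.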

julia
function orbit2tensor(orbit, component, mass = 1)
    x = orbit[1, :]
    y = orbit[2, :]

    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy

    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end

    return mass .* tmp
end

function h_22_quadrupole_components(dt, orbit, component, mass = 1)
    mtensor = orbit2tensor(orbit, component, mass)
    mtensor_ddot = d2_dt2(mtensor, dt)
    return 2 * mtensor_ddot
end

function h_22_quadrupole(dt, orbit, mass = 1)
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    return h11, h12, h22
end

function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
    h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = h11_1 + h11_2
    h12 = h12_1 + h12_2
    h22 = h22_1 + h22_2
    return h11, h12, h22
end

function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # Compute the (2,2)-mode strain from the orbits of BH 1 (mass1) and BH 2 (mass2)

    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function compute_waveform(dt::T, soln, mass_ratio, model_params = nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)
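Read off the code above, the strain computation for a single body of mass $m$ at $(x, y)$ amounts to the following, where $c$ is `scaling_const` and the second time derivatives are taken with d2_dt2. The trace terms cancel in $h_+$ because the $xx$ and $yy$ components subtract the same $\tfrac13(x^2+y^2)$; for two bodies each tensor component is summed over the bodies before forming $h_+$ and $h_\times$.

$$
I^{\mathrm{TF}}_{xx} = m\Big(x^2 - \tfrac{1}{3}(x^2 + y^2)\Big), \qquad
I^{\mathrm{TF}}_{yy} = m\Big(y^2 - \tfrac{1}{3}(x^2 + y^2)\Big), \qquad
I_{xy} = m\,x y
$$

$$
h_+ = c\,\big(2\ddot I^{\mathrm{TF}}_{xx} - 2\ddot I^{\mathrm{TF}}_{yy}\big) = 2c\,m\,\frac{d^2}{dt^2}\big(x^2 - y^2\big), \qquad
h_\times = -c \cdot 2 \cdot 2\ddot I_{xy} = -4c\,m\,\frac{d^2}{dt^2}(x y)
$$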

Simulating the True Model

RelativisticOrbitModel defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background, with state $u[1] = \chi$ and $u[2] = \phi$, where $p$ (the semi-latus rectum), $M$ (the mass), and $e$ (the eccentricity) are constants.
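Written out term by term from the function body below, the right-hand side is:

$$
\dot\chi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p - 2)^2 - 4e^2}}, \qquad
\dot\phi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}}
$$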

julia
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u

    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)

    return [χ̇, ϕ̇]
end

mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length = datasize)  # time at each saved step
dt_data = tsteps[2] - tsteps[1]
dt = 100.0               # fixed RK4 step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
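A useful consistency check on the right-hand side: for $e = 0$ the orbit is circular ($r = p$), and the formulas collapse algebraically to $\dot\phi = 1/(M p^{3/2})$ and $\dot\chi = \sqrt{p - 6}/(M p^2)$. The sketch below repeats the function so it runs standalone and compares it with those closed forms:

```julia
# Copy of the right-hand side above, so this check is self-contained.
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end

p, M = 100.0, 1.0
χ̇, ϕ̇ = RelativisticOrbitModel([π, 0.0], (p, M, 0.0), 0.0)   # e = 0: circular orbit

# Compare with the closed forms obtained by setting e = 0 by hand.
@show ϕ̇ 1 / (M * p^(3 / 2))
@show χ̇ sqrt(p - 6) / (M * p^2)
```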

Let's simulate the true model with a fixed-step RK4 solver from OrdinaryDiffEq.jl (via OrdinaryDiffEqLowOrderRK) and plot the resulting waveform.

julia
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat = tsteps, dt, adaptive = false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "Time", ylabel = "Waveform")

    l = lines!(ax, tsteps, waveform; linewidth = 2, alpha = 0.75)
    s = scatter!(ax, tsteps, waveform; marker = :circle, markersize = 12, alpha = 0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end
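Passing `dt` together with `adaptive = false` makes the solver take fixed steps of size `dt` with no error control. A minimal hand-written sketch of fixed-step classic RK4 (rk4_step, rk4_fixed, and decay are hypothetical names, and this is not OrdinaryDiffEq's implementation) illustrates the idea on $du/dt = -u$, whose exact solution is $u(t) = u_0 e^{-t}$:

```julia
# One classic fourth-order Runge–Kutta step for du/dt = f(u, p, t).
function rk4_step(f, u, p, t, dt)
    k1 = f(u, p, t)
    k2 = f(u .+ dt / 2 .* k1, p, t + dt / 2)
    k3 = f(u .+ dt / 2 .* k2, p, t + dt / 2)
    k4 = f(u .+ dt .* k3, p, t + dt)
    return u .+ dt / 6 .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Fixed-step integration from t0 to t1 with step dt (no adaptive error control).
function rk4_fixed(f, u0, p, (t0, t1), dt)
    u, t = copy(u0), t0
    nsteps = round(Int, (t1 - t0) / dt)
    for _ in 1:nsteps
        u = rk4_step(f, u, p, t, dt)
        t += dt
    end
    return u
end

# Test problem: exponential decay with rate p = 1.
decay(u, p, t) = -p .* u
u_end = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.1)
@show u_end[1] exp(-1.0)
```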

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, mark them as const.
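The reason is type inference: the compiler cannot assume anything about the type of a non-const global, so any function that reads one becomes type-unstable. A small illustration (the names below are hypothetical) using `Base.return_types` to inspect what inference concludes:

```julia
scale_nonconst = 2.0       # non-const global: its type may change at any time
const scale_const = 2.0    # const global: its type is fixed

f_nonconst(x) = scale_nonconst * x
f_const(x) = scale_const * x

# Inferred return types for a Float64 argument.
t_nonconst = only(Base.return_types(f_nonconst, (Float64,)))   # Any
t_const = only(Base.return_types(f_const, (Float64,)))         # Float64
@show t_nonconst t_const
```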

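We deviate from the standard neural network initialization and use WeightInitializers.jl: every Dense layer gets truncated-normal weights with a standard deviation of 1e-4 and zero biases. The first layer applies cos elementwise to the input before the dense layers.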

julia
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight = truncated_normal(; std = 1.0e-4), init_bias = zeros32),
    Dense(32 => 32, cos; init_weight = truncated_normal(; std = 1.0e-4), init_bias = zeros32),
    Dense(32 => 2; init_weight = truncated_normal(; std = 1.0e-4), init_bias = zeros32)
)
ps, st = Lux.setup(Random.default_rng(), nn)

Like most deep learning frameworks, Lux defaults to Float32. In this case, however, we need Float64, so we convert the parameters with `f64` before wrapping them in a `ComponentArray`.

julia
const params = ComponentArray(ps |> f64)

const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.
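To see why the element type matters, the Base-Julia sketch below compares Float32 and Float64 on a tiny increment. It also shows a recursive conversion of a nested NamedTuple parameter tree. `to_f64` is a hypothetical helper written only for illustration and is not Lux's `f64` implementation. The toy parameter tree only mimics the shape of the one printed above.

```julia
# Float32 keeps ~7 significant decimal digits, Float64 ~16: an increment that
# Float32 silently drops is still resolved in Float64.
small = 1e-8
@show Float32(1) + Float32(small) == Float32(1)   # true: increment lost
@show 1.0 + small == 1.0                          # false: increment kept

# Hypothetical helper (NOT Lux's `f64`): recursively convert every
# floating-point array in a nested NamedTuple to Float64.
to_f64(x::AbstractArray{<:AbstractFloat}) = Float64.(x)
to_f64(x::NamedTuple) = map(to_f64, x)
to_f64(x) = x   # leave everything else (e.g. empty layers) untouched

# A toy parameter tree shaped like the one printed above (sizes shrunk).
ps32 = (layer_1 = NamedTuple(),
        layer_2 = (weight = rand(Float32, 4, 1), bias = zeros(Float32, 4)))
ps64 = to_f64(ps32)
@show eltype(ps64.layer_2.weight)   # Float64
```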

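Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics, with each rate multiplied by a neural-network correction. The state vector is

$$u[1] = \chi, \qquad u[2] = \phi$$

and the equations implemented below are

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\bigl(1 + \mathrm{NN}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\bigl(1 + \mathrm{NN}_2(\chi)\bigr)$$

where $p$, $M$, and $e$ are constants (stored in `ode_model_params`) and $\mathrm{NN}_i$ is the $i$-th output of the neural network.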

julia
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, so we can safely ignore
    # it; in general, however, `st` should be used to store the state of the neural network.
    y = 1 .+ nn_model([first(u)], nn_params)

    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))

    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]

    return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
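The right-hand side can be checked in isolation with plain Julia. In this sketch the network is replaced by a hypothetical `correction` function passed as an argument. `orbit_rhs` and `zero_correction` are illustrative names, and the constants are made-up values, not the ones used in this tutorial. With zero correction, both rates reduce to the Newtonian value $(1 + e\cos\chi)^2 / (M p^{3/2})$.

```julia
# Hypothetical stand-in for the neural network: no correction at all.
zero_correction(χ) = [0.0, 0.0]

# Same structure as ODE_model, but with the network passed in as a function.
function orbit_rhs(u, consts, correction)
    χ, ϕ = u
    p, M, e = consts
    y = 1 .+ correction(χ)                        # multiplicative corrections
    rate = (1 + e * cos(χ))^2 / (M * p^(3 / 2))   # Newtonian rate
    return [rate * y[1], rate * y[2]]
end

consts = (100.0, 1.0, 0.5)   # illustrative (p, M, e), not the tutorial's values
du = orbit_rhs([0.0, 0.0], consts, zero_correction)
@show du        # both entries equal 1.5^2 / 100^1.5 = 0.00225

# A nonzero correction scales each rate independently.
du_corr = orbit_rhs([0.0, 0.0], consts, χ -> [0.5, -0.5])
@show du_corr   # 1.5 * 0.00225 and 0.5 * 0.00225
```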

Let us now simulate the model using the untrained neural network parameters and plot the result against the waveform data.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p = params, saveat = tsteps, dt, adaptive = false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "Time", ylabel = "Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth = 2, alpha = 0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker = :circle, markersize = 12, alpha = 0.5, strokewidth = 2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth = 2, alpha = 0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker = :circle, markersize = 12, alpha = 0.5, strokewidth = 2
    )

    axislegend(
        ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position = :lb
    )

    fig
end
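The solve above uses `RK4()` with a fixed `dt` and `adaptive = false`, so every trajectory comes from the same sequence of classical fourth-order Runge-Kutta steps. Below is a minimal plain-Julia sketch of such a fixed-step integrator. It assumes an out-of-place right-hand side `f(u, p, t)` like `ODE_model`. `rk4_fixed` is a hypothetical name, and this is not how OrdinaryDiffEq implements `RK4`. The sketch saves every step, which corresponds to `saveat = tsteps` only if `tsteps` lies on the `dt` grid (an assumption here).

```julia
# Minimal fixed-step classical RK4 (illustrative; not OrdinaryDiffEq's RK4).
# Returns a matrix whose columns are the saved states, like Array(sol).
function rk4_fixed(f, u0, p, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    u, t = copy(u0), t0
    us = [copy(u0)]
    for _ in 1:nsteps
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
        push!(us, u)
    end
    return reduce(hcat, us)
end

# Usage on exponential decay du/dt = -p u, whose exact solution is exp(-p t).
decay(u, p, t) = -p .* u
sol = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.01)
@show size(sol)                     # (1, 101)
@show abs(sol[1, end] - exp(-1))    # tiny fixed-step error
```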

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural ODE: the mean squared error between the predicted and observed waveforms.

julia
const mseloss = MSELoss()

function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p = θ, saveat = tsteps, dt, adaptive = false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
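`MSELoss()` computes a mean squared error. Our understanding is that its default aggregation is the mean over all entries, so for a waveform vector the loss should reduce to the sketch below. `mse` is a hypothetical name, written with Base Julia and the `Statistics` standard library.

```julia
using Statistics: mean

# Mean squared error: the average of the squared entrywise differences.
mse(ŷ, y) = mean(abs2, ŷ .- y)

pred   = [1.0, 2.0, 3.0]
target = [1.0, 2.0, 5.0]
@show mse(pred, target)   # (0 + 0 + 4) / 3 ≈ 1.3333
```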

Warm up the loss function (the first call triggers compilation):

julia
loss(params)
0.0006864946400429274

Now let us define a callback function that stores the loss at each iteration and prints the training progress.

julia
const losses = Float64[]

function callback(θ, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
    return false
end
callback (generic function with 1 method)
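The callback above always returns `false`. Optimization.jl treats a `true` return value as a request to halt, which allows early stopping. The sketch below is a hypothetical variant (`early_stopping_callback` and `threshold` are illustrative names). A plain NamedTuple with an `iter` field stands in for the optimizer state so the example runs on its own.

```julia
using Printf

const recorded_losses = Float64[]

# Hypothetical variant of `callback`: records and prints the loss, then returns
# `true` (asking the optimizer to halt) once the loss drops below `threshold`.
function early_stopping_callback(state, l; threshold = 1e-6)
    push!(recorded_losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return l < threshold
end

# Stand-in for the optimizer state: only the `iter` field is used here.
early_stopping_callback((; iter = 1), 1e-3)   # false: keep going
early_stopping_callback((; iter = 2), 1e-7)   # true: stop
```

To use a custom threshold with `Optimization.solve`, one option is a closure such as `callback = (θ, l) -> early_stopping_callback(θ, l; threshold = 1e-5)`.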

Training the Neural Network

Training uses the BFGS optimizer with a backtracking line search. This appears to work well here because the Newtonian model already provides a very good initial guess.

julia
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob, BFGS(; initial_stepnorm = 0.01, linesearch = LineSearches.BackTracking());
    callback, maxiters = 1000
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [1.1684905075510235e-5; -4.756234193336002e-5; -9.834645607041006e-5; 9.27074142963352e-5; -5.583238453256071e-5; 5.513598080143146e-5; -5.3152118198318194e-5; 0.00016854524437804217; -0.00019818036525975926; 4.023786459582534e-5; 6.343576387734276e-6; 6.203808879937171e-5; 7.028735126360051e-5; -8.307965617857265e-5; 0.00011663161421886803; 2.353656782358725e-6; -7.852615817684807e-5; -1.4547541468323432e-5; 0.00014218158321448301; -8.962746505859295e-5; -5.7716213632383133e-5; -4.23954734287463e-5; 5.547955879589149e-5; 3.7869260268019525e-5; 8.198458817778692e-5; -2.160575422745688e-5; -0.00019508181139819516; 9.92121385933282e-7; -0.00014845223631679104; 0.00025220695533806657; 0.00012981895997633323; 0.00010813903645594471;;], bias = [1.865859080332062e-17, -1.2824401977179955e-17, -2.1914020048586422e-17, 1.3953906849459077e-16, 1.172737118145792e-17, -7.63723918364325e-18, -2.6090242252946762e-17, -7.754440293700479e-19, 6.5590527410027875e-18, 1.447365033888795e-17, 9.213948889934352e-19, 9.5054922366098e-17, 1.0030537726226057e-16, -2.855907520598836e-17, 7.342817000470606e-17, 2.8958011037053682e-18, 2.2448218901465697e-17, -2.5934926002178136e-17, 7.769523381962897e-17, 1.0152856126791147e-17, 8.858032523348367e-18, -2.5924201151734986e-17, 3.6167871105813494e-17, 3.858662067995849e-17, 1.0144088290169518e-16, -4.250749332739879e-17, -7.996053400570479e-17, 1.1704354945032432e-18, -3.529262225762163e-16, -1.9489141108862974e-16, 2.4562664924127276e-16, 2.0552684131197199e-16]), layer_3 = (weight = [2.0886512883486433e-5 -1.0384860816805394e-5 -9.92340958154373e-5 -0.00023337696095123765 -6.719626455461609e-5 0.00011228646124353629 -2.775116647254686e-5 -0.0001373581710330711 -9.93750602182476e-5 -1.628159972907548e-5 6.670039672743454e-5 -4.022519734382961e-6 -8.666834239276088e-5 2.8168140287642277e-5 7.492126840505988e-5 5.650303845947317e-6 4.486371633819371e-5 5.2468136144931645e-5 
-7.762909283056463e-5 -2.2530695988199503e-5 -6.229308573098254e-7 4.899205831940874e-6 -1.5006419945086865e-5 0.0001162461828963268 -0.00012564019763968806 -0.00023816845552656864 0.00011265720038781751 0.00010839276163013575 4.479745419220503e-5 5.289979688232301e-5 0.00013163737255953567 9.073361232284652e-5; -5.991680779548303e-5 6.0145624165349133e-5 5.025701205308536e-5 6.925079026554707e-5 -4.2840771921285996e-5 -0.00011384374272015841 7.883465253111944e-5 0.0001348178713029887 0.00010448932255680441 0.00011217414533316205 -2.8544708466464066e-5 -0.0001297731858997882 9.199162844165251e-5 -1.2790864157390234e-5 -1.138069535077591e-5 6.376366321059935e-5 -0.00012432083064627816 5.937419803526311e-5 -5.765940268570297e-6 -6.628699963617137e-5 -0.00016994991789975167 -1.5245925961019946e-5 5.888010958155654e-5 2.233459156709306e-5 -2.5446962236268356e-5 -0.00012901169872780275 -1.0667107626756084e-5 -5.9644024868569674e-5 -8.53901282156176e-5 0.00018054211353013812 -2.7472945357980863e-5 9.249442621662254e-5; 0.0001503598211817801 0.0001356452300930741 4.519020125342726e-5 2.0254415848905024e-5 -7.291068195913513e-6 3.60687370081772e-5 2.0442668519249776e-5 -6.865478335686872e-6 -7.816060338913245e-5 0.00010038438971576973 7.814714239401601e-5 5.624461458836304e-5 2.5180171851101535e-5 -8.598907934644801e-6 0.00010461683608780564 8.417887092662052e-6 -0.0001445752243492432 5.4431009393453945e-5 6.919634664852642e-5 -0.00020710841296051334 -0.00021790916458513932 8.212513397637506e-5 8.130416585089308e-5 4.073712965663295e-5 -4.415628846179194e-6 -2.1176984101032647e-5 -5.716784313366453e-5 -9.65653127971678e-5 0.00014454641835952533 -0.00018500140528008698 0.00019190006763863127 0.0001110838818382758; -6.108892722390552e-5 -0.00010922487291377052 0.0002144846423006596 7.722397111952125e-5 -4.049471846069246e-5 -0.00014348791756769677 0.00012652662388877917 3.794887025129504e-8 -2.0887558831470487e-5 1.8649420225789308e-6 -7.487293060971381e-5 
-1.2505532878137278e-5 1.7931248466041038e-5 4.052000066578443e-6 -6.35627164329433e-5 2.195034920892257e-5 -5.9501116821636666e-5 -2.8600184860848428e-5 0.00021861523254205922 -9.473946021033222e-6 -3.3900867320923554e-5 -3.7145125871617367e-5 -7.525988058755209e-5 0.0001179043376295828 -7.879682544721418e-5 8.108308810638296e-5 0.00010687970478915602 -1.4864807609270351e-5 -4.841013816171851e-5 0.00016364023594081663 -5.712422517813203e-5 3.23262704013821e-5; 4.649388761815362e-5 -7.074089984901344e-5 -2.900254406378375e-6 -0.0002108788630205913 2.350482469107461e-5 0.00010442061289420317 3.572175107117096e-7 3.2614449134320165e-5 -8.640345362736941e-5 -3.8098762009056087e-7 0.00011035078756019212 -7.69233392414448e-5 3.482900324950103e-5 -0.00013451565059160555 0.00010987181854641827 2.6516229869870457e-6 -1.8591033898087068e-5 -4.1373210427381715e-5 2.570288240064776e-6 -6.161641971932383e-5 0.00010754521468395354 -6.405911738809781e-6 2.69697757691328e-5 -7.406224353265174e-5 -6.598113572483639e-5 -0.00012123931987237414 -0.0001190668644479029 -2.7029769574555406e-5 -0.00016030452565547548 4.386883626214593e-6 2.5631181467073785e-5 2.1791420346453444e-5; 4.492085381021722e-5 4.681998788688125e-5 1.1738988263413851e-5 -5.809304539620764e-5 -0.0001278407025038573 9.50791773847928e-5 0.00015312687637011258 7.645887407664268e-5 -1.469520876169732e-5 2.1222994377201837e-5 -7.457674192304717e-5 -2.835765272361121e-5 0.00010177688554354477 -2.5318528883915437e-5 -0.0001796078458083272 8.017907848075974e-5 -0.00016124718504765656 0.0001866283829203527 0.00015553926377282398 -9.439069714106821e-6 6.759022389265396e-5 2.5062793696595594e-5 1.4377594283320411e-5 0.00010774798019516155 2.2931654797460087e-5 9.338569097413647e-5 8.111955421001706e-5 1.6179264179263384e-5 3.341745204482902e-5 4.6635389566192466e-5 -6.973446116358728e-5 9.25998802758362e-5; -1.947638547447107e-5 -0.0001073025944572053 0.00012286410283537943 -5.6912902706456846e-5 0.00010645224911401756 
0.00020565370296369087 -1.302530135601965e-5 2.5446713792693948e-5 8.461595798848845e-5 -0.00016344983514157268 -5.857261775984487e-5 -8.961904886403383e-5 0.00010963173345205054 -1.737554085805179e-6 -6.905648687846808e-5 -3.482583478783314e-5 0.0001386207565735894 -1.085391185936585e-5 6.383733644157448e-5 -0.0001058577275182262 -3.471289009778993e-5 0.00014418891596848924 -5.8280455319827744e-5 4.3083846052259746e-5 5.999103635774749e-6 -2.0595273141433355e-5 -1.062650544213351e-5 -8.515586732412744e-5 8.221745311289728e-5 6.019300936144128e-5 0.00017254352651776726 1.4209661188728845e-5; -9.793854427700486e-5 9.850900802119212e-5 -0.00013970247549879327 1.8983167910703937e-6 -0.00011692821884937024 -3.332374421008306e-5 -9.675670320577376e-5 -0.00010422886429296533 3.0020171616347666e-5 -0.0001510845179291395 1.0252138745252955e-5 8.694751137225564e-5 1.057488567310595e-5 -1.3146738856464103e-5 -1.1407176534714614e-5 0.00010160758371055494 6.872133941213422e-5 -4.686092892372599e-5 -4.5148630524643275e-5 -2.6895604823963424e-5 -0.00010194852096903494 -5.178052400880945e-5 9.962236777935736e-5 6.436260424377854e-5 3.189132963596698e-5 -5.504435124750783e-5 -4.3338717878210027e-5 1.902993950350578e-5 6.134219416707631e-5 5.010004293582999e-5 -9.705215074064655e-5 -2.9735669377272245e-5; -9.650537241340354e-5 4.112797695245397e-6 -2.8278733328777014e-6 -2.431171115984239e-5 9.387905164532085e-5 3.574471419471026e-5 0.00013382391342803515 0.0001250156391403382 -8.642603926212739e-5 4.243104275575363e-5 -1.4071650742976495e-6 -0.00016365546172108684 -5.938529203502503e-5 3.810248825959145e-5 -0.00022336227549328311 -4.7622549674443324e-5 -2.720022813389891e-5 -5.968394099126298e-5 5.070071064313965e-5 5.001967373447855e-5 -2.4605046848033288e-5 -8.268117663773229e-5 -0.00016701916602995532 2.0682723414767193e-5 6.074094279812839e-5 3.6909695050872585e-6 -2.68332706685541e-5 -0.000201565398230187 -0.00030435894869820817 5.919307376525539e-5 -0.00011349754923809667 
-6.613780270652005e-5; 6.894806892596058e-5 6.588190035181052e-5 -5.356548399795638e-5 6.813991146188868e-6 -0.0001754253648898196 -0.00012160885353587615 -9.635873402949475e-5 9.650051194333633e-6 4.910954839652447e-5 -4.9895698342156525e-5 5.28955387998824e-6 0.00010840565255129686 -3.879746191451318e-5 0.00015994511659009442 -5.4892265782838035e-5 -0.00016180651846589653 -7.48347852124456e-5 -1.653216031918893e-5 -6.866899321103031e-5 0.00015062660849871203 0.0002472924347755868 6.797457490101375e-5 0.00010658429846152617 -0.00012163789915867242 -6.801496465705278e-5 -1.8546716549467546e-5 -0.00012419606037219046 1.4764196946950634e-6 6.784557944847759e-5 -0.00012581839519333295 2.6951830929465293e-5 -6.37297444263974e-6; -0.0001365522551890667 3.752281156319935e-5 0.0001878226099693941 4.687816836472171e-6 -0.00011081425362082391 2.0609691655941256e-5 -0.0001777137152925339 -0.00015084716904258615 1.973011205389809e-5 9.00768278557602e-5 -5.742050814607364e-5 1.0713333377303018e-5 0.0001615104771377313 -0.00010087448061259666 -8.372974878458525e-5 4.6913147908149314e-5 -1.377030855966466e-5 7.275336212895626e-5 1.3486373628061228e-5 7.704395067829435e-5 -8.012138312498299e-5 0.00013842942245954464 -0.0001318853832157534 -2.82337390484959e-5 4.3390489391263425e-5 -3.769428110846503e-5 0.00015222909098209532 9.14778353222441e-5 -5.630154953852537e-5 3.9359685270559265e-5 -8.705833932200674e-5 6.76003907033966e-5; 6.163793469353238e-5 -2.882821908530274e-5 -1.823797000940209e-5 -3.028204090912104e-5 -9.3756940343487e-5 -9.5532004790916e-5 3.4474325646110984e-5 0.00016388879331100207 5.1014916523508566e-5 1.2704832358973281e-5 4.578007966086862e-5 0.00011647980309829557 -9.6428424596857e-5 2.7405081625723717e-5 -0.00010628308345291703 -2.8459731849910345e-5 0.0001699479343257135 6.21801463308989e-5 7.48859110059383e-5 -6.6654407349841256e-6 -9.006770241500344e-5 0.00012575961764708682 1.3626136849448162e-5 6.269043106625395e-5 8.060028076909814e-5 
3.700148036626745e-5 1.9224739004313867e-6 -0.00016597113980639132 7.05973233449432e-5 7.630718256306523e-7 -5.3023255616035985e-5 -6.913937076421318e-5; -0.00017198994011053003 8.337034148985306e-6 -3.815789297006924e-5 -0.0001533822605888631 0.00012462983885413916 0.000205340770651651 9.806430160074406e-6 -9.40057461885031e-5 9.239228255681534e-5 4.1281987239903874e-5 7.536222709675202e-5 7.475320034061525e-5 -9.117408172826292e-5 -0.00014074677438841166 -0.00013111986755759265 -0.00016392932867531783 -0.00012497073749911952 0.00015171972789901193 -0.00010060925252402793 0.000126607167647286 7.676034600806378e-5 -9.605349169944492e-5 0.00011591984820186111 0.00011354585602746258 2.5578400698555968e-5 6.363610463083614e-5 3.7376089483572166e-5 -0.0001365918678966642 -0.0001245022676921691 -2.5330825348655488e-5 3.834969628585766e-5 2.0366004708173285e-5; -0.00011352584235774369 -0.0002940098636148853 -0.0001418764875101219 1.8243547196656023e-5 6.512996855779813e-6 -1.6689473091356955e-5 -2.612628097758883e-5 2.614083157334066e-5 -9.303419913100498e-5 -5.215953181086932e-5 2.1187152264788494e-5 -6.927706247133837e-5 2.360306304999552e-5 7.439097705592026e-5 2.337819230397893e-5 3.240615139893227e-5 -0.00015396410865414393 -0.00013277998343747119 0.0002172155803005939 6.0796741558627227e-5 2.864701698363081e-5 0.00019429887495298062 -8.612358190057569e-5 8.098145942730921e-5 -8.466803385581282e-5 -7.629740841858078e-5 -3.691959475451672e-5 -0.00018951811919160218 0.00010151827907485867 1.2802664965919651e-6 -2.318237941416441e-5 -1.911677984516569e-5; -9.90671865105447e-5 0.000167567228800729 0.00010608356069551972 -0.00013052917665790605 2.666418231511051e-5 0.00010259347300289226 1.8157976683690418e-5 -8.787288561451528e-5 -5.1715428980412074e-5 5.688201463326802e-5 -0.00012741393174964984 -5.509288121120026e-5 -6.980637726828384e-5 1.6115422568513368e-5 -8.277318147465299e-5 -1.6466757642226783e-5 0.00019177553304016692 -0.00011531533118559089 
6.17482332932994e-5 2.269362994961305e-6 -1.9137841717926455e-5 -1.3781926554108548e-5 -9.450592642130524e-6 9.933112278358999e-5 3.363808764388901e-5 -0.00027377181277133566 -2.1997329440080033e-5 -6.243743291579234e-5 7.43549208037468e-6 0.0001000117958943698 1.3472243609255124e-5 2.564440592378152e-5; -9.716315684228733e-5 -9.369601751997697e-5 -2.048959222083595e-6 -1.075668818841186e-5 0.00012251715385035034 -0.00022005083586874993 -0.0001222987673277372 -3.517768413032444e-5 5.0234323917243404e-5 -5.631300586130092e-5 -3.709966471389464e-5 6.337101809593262e-5 0.00017728950048011123 -0.0002194761661844671 -2.725946864258772e-5 8.609894144727321e-5 8.553205703763599e-5 -0.00017367523536406058 -0.00010516272299311966 6.85526495185336e-5 3.121271152893432e-5 4.540472000809649e-6 0.00017897822113853304 -2.9168926586714285e-7 2.4275167798333063e-5 -6.499044265654632e-6 0.0002657036632267367 0.00017111294006143168 2.4510881540189168e-5 -3.9071449227348265e-5 7.124858494061291e-6 1.59373896717442e-5; 9.474625907973601e-5 4.6134174723597314e-5 -0.00023721686962159132 4.783368926513329e-5 9.181839190791284e-5 3.033065472340608e-5 -5.032948543299037e-5 -4.219142506697896e-5 -0.00011894097629337446 9.382366037829849e-5 6.245944454935603e-5 5.049733285369926e-7 2.5731347292496144e-5 2.079466462095288e-5 7.979854630905967e-5 0.00020075181741635538 0.00012970171296405232 -5.707727226356771e-6 -1.1678750027908218e-5 -3.7865660900535034e-5 -2.1957076072391167e-5 7.475949455187708e-5 -2.0022457150440457e-5 5.815317267337901e-5 2.0471151898564213e-5 0.00011795174718504022 -0.00013098278073229587 -4.8749154715144596e-5 0.00011614023746590071 -2.2572127851352744e-6 0.00020199094210186487 8.719528110700381e-6; -2.6468726825355355e-7 0.0001228136181244584 7.151638526656065e-5 2.226558097126183e-5 6.343214152101092e-5 0.00017086931000183923 -0.0001961727746883444 -2.7476166745043323e-5 -6.767106135419017e-5 -6.887316596741347e-5 -1.7107175902237174e-5 -8.900802533654758e-5 
-9.452735705884758e-6 -0.00015237322697695294 -0.00012206368303205087 -5.165424119961171e-5 -5.2319929475663494e-5 -0.00022294517004676795 -1.750144367442459e-5 -6.683567955670714e-5 4.79481943121614e-5 7.617387125454371e-5 0.00010712375040119173 -3.991165687959219e-5 -9.66825744607483e-5 -1.870677335584212e-5 1.779220763375945e-5 -3.865376749938078e-5 -5.8588196064096804e-5 4.082164484966783e-5 -4.345968666066285e-5 -0.0001479055125154501; 0.00015671655901217373 0.00011593683202493258 3.5736683365865256e-5 9.387878798333276e-6 0.00015109073037722316 2.128367022935476e-5 -3.813026442101028e-5 -4.422658557914746e-5 2.6932361742078514e-5 -0.00013637589390880111 3.545257177295644e-5 0.00010652278333411642 0.00011007369803239572 -7.754196221965582e-5 6.514036694981189e-5 5.7345538970302044e-5 -4.363574144106502e-5 1.385140776522223e-5 4.145785614937829e-5 7.9254902538123e-5 7.660339078840946e-5 -0.00019993183477806558 -6.989015773895023e-5 0.00010664261835602452 6.95958432506098e-5 0.00012430570395638674 7.067000742507494e-5 -1.7070348895750748e-5 -9.844369296942737e-5 -0.00010253255559395481 4.33896720906883e-6 0.00019166986648898795; -0.0001732285721398425 7.29810732711587e-5 -2.1029812460967804e-5 -0.00017606094236797035 -5.676609208228981e-6 2.837595881757066e-5 0.0001642690501004942 -0.00026247714170564673 0.0003331862123279656 6.55768769040248e-5 -6.131907624038887e-5 -1.284053844827893e-5 -0.0001258618697927764 5.624235442993724e-5 -9.533731402441801e-5 -0.00016105785955986958 -1.0386210606871449e-5 8.131686869929539e-5 0.00013746466963241428 -3.454637529174221e-5 -2.0818167586901746e-5 5.0593431905518486e-5 1.8040477397527763e-5 -2.4387488638958034e-5 -2.6894647208596443e-5 -5.547947269477867e-5 5.821388428488763e-5 -0.0001099828635513319 -9.905359668734725e-5 -0.00012038464531609734 3.066488870193239e-6 -0.0001111492577645377; -7.036141136819351e-5 8.678300405694186e-5 2.552918890972643e-5 4.01532507148068e-5 9.326366312410643e-5 3.512890002997558e-5 
2.611213780610151e-7 -0.00014775398262470703 0.00013994587975859227 -5.251815122943316e-5 0.00018048358656773498 -2.165348082240755e-5 2.4809116488484087e-5 -3.2106430759143945e-5 1.099609112673683e-5 -0.00016451464575407873 -4.406073089381907e-5 2.1542526204885473e-5 -3.1395566062267664e-5 0.00012607436843057752 -6.941066198675014e-5 -4.029125006678471e-5 6.903801300171726e-5 0.00022466023261034928 -2.60377800810936e-6 0.00013418060012003536 7.805045793603767e-5 5.514945410499504e-5 2.8813199641498624e-6 -3.166135679394064e-5 -3.734989326955707e-5 -9.953844523449067e-5; 1.0228152439215116e-5 -6.448375958135959e-5 -0.00014066502213828686 8.016320911118313e-5 7.573170709072958e-5 2.9488721827495535e-5 -7.480797236204741e-5 -5.6986427318779224e-5 -2.934020582179678e-5 -0.00013606687886059913 -0.0001473992556042663 -5.4692903597621346e-5 7.607608544056694e-5 -3.195421910496919e-5 -9.931092378102564e-5 -6.403843459558407e-5 0.00016295265644438052 6.842460110199908e-5 5.695573830743166e-5 0.00017460047505445312 0.00016460127202443196 -4.602367846398066e-6 -0.0002453012379759609 -1.9600989128245753e-5 -0.00011705477978686367 -0.00012754876976252256 -8.831132751075256e-5 -3.5285348951599704e-5 0.0002007783810200531 0.0001487777958915104 5.339742759823559e-5 1.903321075368302e-5; -8.148181004690958e-5 0.00015064890831234973 5.387938508756077e-5 -4.0390715729524196e-5 4.75992332223446e-5 -0.00010612150581328977 -6.196399553920002e-5 -2.9854437888201392e-5 -7.457188403239949e-5 5.577172705778919e-5 -0.00010901028654080843 0.00014127393878988193 9.336059374726643e-5 -4.6251765160202866e-5 -7.029741900514958e-5 5.82736669468318e-5 -9.643258271179031e-6 -1.9635117131485776e-5 -8.890172975843522e-5 4.793375320507765e-6 -3.9254025613272734e-5 -9.871280562839392e-5 3.6268130781441328e-6 -6.656187688241657e-5 -3.614001761780679e-5 -4.5170754309594654e-5 6.869600692990619e-5 3.3190843533873264e-5 -7.337172209990786e-5 -9.218902380513637e-7 3.5457893683268226e-5 
-0.0001419551532297235; 0.00011975422010316558 -0.00017214570336484262 -0.00027085410922274696 -3.839253641964402e-6 -4.483255044881151e-5 0.0001514803202359668 6.205753514102758e-5 7.260344635022465e-5 1.866324007886331e-5 -0.0001185755995914612 -4.581451732640109e-5 0.00012582049155633083 -0.00013230139531635542 0.0001043868810290209 9.779209129072101e-5 -6.740465797814099e-6 5.1547250658293946e-5 -1.4675872554401387e-5 1.0711160465399413e-5 9.813630956948933e-5 3.098933800963792e-5 -6.138897380999212e-5 -0.00014040806356326572 -4.1671186883216e-5 -3.373722076623527e-5 -1.600940097098186e-5 -4.6028881589622844e-5 2.5249511489780125e-5 0.0002601685896285577 -5.077108882806857e-5 -0.0001862846352390256 -2.384974569686673e-6; -8.73342923968119e-5 -2.5005097990180588e-5 4.3061414549924776e-5 5.922609039884205e-6 0.00012902126946402764 1.772225187425544e-5 0.00016536547810221111 -1.669818447577017e-5 7.055268543458798e-5 6.321326691806242e-5 -6.415398638081013e-5 -0.0002502135303019535 4.249144877223498e-5 -0.0001404141785259322 1.80791375952399e-5 -0.00017742515591988547 0.00015957269534385134 -0.00016143012725832132 -8.595098005922906e-5 -7.648530844915201e-5 -2.6302486725290568e-5 -0.00010412783557865828 -1.3333260534528653e-5 0.00011255155732771294 1.510625498314969e-5 -0.00011306515625063502 4.9764445071696925e-5 -1.0306877806109606e-5 -0.0001264801958258102 -9.301824314958028e-6 -3.978285712389018e-5 -5.75972790046519e-5; -0.0001710065380486901 1.2107378997593525e-5 6.845939332355614e-5 3.201702110322119e-5 -1.4386551820512515e-5 2.9093985167259732e-5 -2.5906233541041116e-5 -0.00013805934305334672 2.8871679196275348e-5 9.141816825808327e-5 -0.00011163550244441803 -3.301525885968288e-5 -8.27496232980198e-5 0.00016561978826734253 -0.00021769896285173823 6.787797155061191e-5 -9.446246714129055e-5 4.8598371895192225e-5 0.00017088391449736665 6.166435467979828e-5 -0.00014093954542682298 0.00010312678476997943 4.473833077776842e-6 1.113405868613472e-5 
8.65730134953468e-5 2.6040429465634653e-5 2.4980877232996047e-5 6.573537665621764e-5 0.0002254934542995519 2.1777784231479783e-5 0.00017864984964920623 -4.9324579786353345e-5; 4.567550580013707e-5 -1.6482699964578773e-5 0.00013995414495348871 -0.0002758874382090583 -3.628753826723981e-5 -0.00012928111233868252 -0.00017629848140877303 -8.023814577261512e-5 7.19638280605367e-5 -6.477160629237727e-5 0.00012292546464941619 -0.00019989636190820975 -0.00022472419938657844 1.6715621069858e-5 -2.5026211249667665e-5 -4.451674091633255e-5 -5.4645921653167735e-5 -3.394552756453695e-5 -1.7251222987563525e-5 6.296229706259262e-5 -6.226101528972916e-5 -1.018077661519583e-5 -5.769692713165274e-5 7.684838213372304e-5 -8.438809733724499e-5 0.00011126585171205178 6.277896475861043e-5 -3.096812203530069e-5 -1.1174393210404765e-5 0.0001509882656426752 -2.1253325274470148e-5 0.00016191062762282884; -4.705865676252193e-5 0.00010958289904200603 -9.91450024039752e-6 6.298227499618782e-5 5.454399406707108e-5 -6.805677576076983e-5 -2.2115566916094667e-6 6.994507539467101e-5 -2.8019909493997723e-5 -6.124263045231726e-5 5.924624174630241e-5 6.514009057437196e-6 -3.44926668712185e-5 6.575422567229253e-5 5.9831370619665976e-5 0.0001567007781905771 6.562149765349407e-5 3.6296205236874356e-5 2.589543512024928e-6 -6.209839220710483e-5 -0.0001746302080372607 -0.00011206176741135209 -0.00014193128326761347 -0.00013756167781858633 -2.9823317434336197e-5 -6.04904346782062e-5 7.275838988982051e-5 -0.0001092745590521305 -0.00017524394961391883 0.00010117086613003047 -0.00014609476083742205 -0.00013355050059354057; 9.700950046073751e-6 -0.00021020508995558647 -8.49276080988293e-5 9.705398713425329e-5 -8.89038971071088e-5 0.00019871909478400335 -0.00012423429563920518 -2.4570310982983302e-5 5.056392490236205e-5 2.128694848227723e-5 -6.129310770472132e-5 0.00013128815351240783 -0.0002384074147120072 7.253553792555384e-5 0.00010191436320031675 0.00023888408338837707 -9.777497018500676e-5 
-5.4337685127273076e-5 9.27665417986063e-5 -2.9588775418111538e-5 -2.3536089996802907e-5 -0.00018955852703935987 0.00012128203104669662 2.7676612604529923e-5 9.927087718218392e-6 0.00010896400215290873 -8.608676273421287e-5 4.0670667091155504e-5 6.316971428180867e-5 -5.6107550902976e-5 -0.0001584696261705764 3.0876643212138655e-7; 2.811689684365204e-5 0.00010102299803891067 -0.00015528014681278688 2.837007652176556e-5 -0.00012203435379290153 0.00018187738261674403 9.302937806519298e-7 4.15625518295035e-5 -1.2367406255557539e-5 -4.320388232683931e-5 -0.0001709254056399243 -3.838078540343501e-6 2.8504490743739638e-5 -4.5631232736399716e-5 -3.490355719206714e-5 -6.625108561222652e-5 7.05098873518868e-5 -2.6490445619334648e-5 9.038027293770885e-5 0.0001144972711850423 -6.0047382233665634e-5 0.00017742216825042693 0.00011757448466285002 -9.368562221736821e-5 -0.00016357187176565676 4.1728392731396365e-5 -8.868756320536293e-5 0.00012994792724982629 -5.1351241366701335e-5 -4.146609611241952e-6 -6.573917833831867e-5 -0.00015717528638872227; 7.376263430893607e-5 7.20204935697121e-5 0.00016850662552139354 -8.880290795962736e-5 0.00010385801771741667 9.31448348664101e-8 -9.269804821263705e-5 0.00011965494388102927 4.544118393649223e-5 2.901169329385518e-5 5.263732071456693e-6 -2.5330767932186292e-5 -2.9864742720686786e-5 -2.3066087925993013e-5 -0.00012244394554903007 0.0001262324532199855 -2.154031415733083e-5 -9.835000481137816e-5 -3.048122369539407e-5 -2.2351332410773857e-5 -3.472226115439705e-5 -0.00012839172075649474 -8.406887162157962e-5 -4.40374441331508e-5 -4.076088760660513e-5 -1.8735765372126238e-5 -9.98765779688196e-6 -5.342637985626055e-5 -2.200216102385304e-5 0.00010005855630705973 4.308615654560655e-5 -0.00022177466406167768; 4.2568029629012036e-5 5.7516917466830626e-5 5.168275537121564e-6 -0.00012657470911234213 7.546453302767359e-5 -3.0038065643918293e-5 -7.026115458883125e-5 0.00014579815938179494 6.361896844157486e-5 -7.878562287386324e-5 
0.0001766512028085453 1.5168960588756219e-5 6.212489420123686e-5 -8.662742443574241e-5 0.00012255103824701733 -0.00017744571637321138 8.372877668863491e-5 0.0002037972185904546 -3.710355708915889e-5 -9.869117316282897e-5 -0.00017382038318045204 -6.240305488189496e-5 -7.31851655031159e-5 0.00015506966478698127 -4.589814535743823e-6 8.256130563369165e-5 -0.00015783727253744883 2.7701372845160204e-5 -1.2133929386782432e-5 -0.00018456997027111848 -0.00018564331953837792 -3.143424005555414e-5], bias = [-7.077877976144833e-10, 2.2658209375401335e-10, 1.895150205598621e-9, 7.914322432827598e-10, -2.3155219693020318e-9, 2.936843847972799e-9, 1.6422780131962876e-9, -1.8919672780179916e-9, -4.040986269139523e-9, -8.626586398224321e-10, 7.969045828037919e-11, 1.1949920341182442e-9, -6.939553947035184e-10, -2.619094997934088e-9, -7.526063178898716e-10, 4.7590263591424e-10, 3.0245668743182823e-9, -3.1057633525056182e-9, 3.142212270954802e-9, -2.267671143843129e-9, 1.9154449008983554e-9, -7.671245999165557e-10, -1.845909366678581e-9, -5.91645696691385e-10, -2.544760613005634e-9, 1.6470855792651397e-9, -2.275799669656608e-9, -2.1903341918569904e-9, -5.207983738838187e-10, -7.921548115607132e-10, -1.2204345111350288e-9, -9.86302791740207e-10]), layer_4 = (weight = [-0.0006213388552397448 -0.0005852080835998407 -0.0007366882718395276 -0.0006638187867284822 -0.0006454110781190602 -0.0008320923085714789 -0.0006372297159005324 -0.000793835663527801 -0.0005006796816182057 -0.0006611223876407968 -0.0008636036052724226 -0.0007270284598566308 -0.0007333503348887806 -0.0006122408381893225 -0.000707179688117624 -0.0007355868907320488 -0.0005744348120905261 -0.0008147111875647709 -0.0006389813482977939 -0.0009236124373254438 -0.0005304279035905962 -0.0008519208027745623 -0.0006511301721029659 -0.0006110807776763492 -0.0007076413562766184 -0.0006657167280090714 -0.0008433324337438541 -0.0008168460077245483 -0.000650385606198951 -0.0005759624592717979 -0.0005797123508444127 
-0.0007233712809003027; 0.00044889610096692087 0.0003661254514817301 0.0002593839904216043 0.0002926990395906213 0.0002092977591556634 0.0003126382699822454 0.00011441932108943487 0.0002956754894445875 0.00019772717694069324 0.00026657856173489195 0.00018318603866429416 0.00037225327933478015 0.00025794510364980666 0.00013471509629861797 0.0003537598258135786 0.00026354055202984153 4.05773223403231e-6 0.00022097530775401614 0.00013393010290132143 8.314059367308108e-5 0.000140162437616706 0.0002967806225719931 6.457033239081356e-5 0.000262327366966068 0.00018456033273684817 0.00019541087138822585 0.00012783458296557605 0.0002216247887118604 0.00028720816860510476 0.00035043842363551047 0.00023918194190362615 0.00031205322541316504], bias = [-0.0006755511985572984, 0.00024224979972483265]))
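BFGS builds an approximation to the inverse Hessian from successive steps and gradient differences, and uses it to choose a search direction. A line search then picks the step length along that direction. The plain-Julia sketch below illustrates the idea with an Armijo backtracking line search. It is a textbook version written for illustration (`bfgs` and `backtrack` are hypothetical names), not the Optim.jl or LineSearches.jl implementation, and it omits details such as the `initial_stepnorm` option used above.

```julia
using LinearAlgebra

# Armijo backtracking: shrink α until f decreases by at least c·α·gᵀd.
function backtrack(f, x, fx, g, d; α = 1.0, shrink = 0.5, c = 1e-4)
    while f(x .+ α .* d) > fx + c * α * dot(g, d) && α > 1e-12
        α *= shrink
    end
    return α
end

# Textbook BFGS maintaining an inverse-Hessian approximation H.
function bfgs(f, ∇f, x0; maxiters = 200, gtol = 1e-8)
    x = float.(copy(x0))
    H = Matrix{Float64}(I, length(x), length(x))
    g = ∇f(x)
    for _ in 1:maxiters
        norm(g) < gtol && break
        d = -H * g                         # quasi-Newton search direction
        α = backtrack(f, x, f(x), g, d)
        s = α .* d                         # step actually taken
        xnew = x .+ s
        gnew = ∇f(xnew)
        y = gnew .- g                      # change in gradient
        sy = dot(s, y)
        if sy > 1e-12                      # curvature condition; else skip update
            ρ = 1 / sy
            V = I - ρ * (y * s')
            H = V' * H * V + ρ * (s * s')  # BFGS inverse-Hessian update
        end
        x, g = xnew, gnew
    end
    return x
end

# Strictly convex quadratic ½xᵀAx - bᵀx, minimized at A \ b = [0.2, 0.4].
A = [3.0 1.0; 1.0 2.0]
b = [1.0, 1.0]
quad(x) = 0.5 * dot(x, A * x) - dot(b, x)
grad_quad(x) = A * x - b
xmin = bfgs(quad, grad_quad, [5.0, -3.0])
@show xmin
```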

Visualizing the Results

Let us now plot the loss at each training iteration.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "Iteration", ylabel = "Loss")

    lines!(ax, losses; linewidth = 4, alpha = 0.75)
    scatter!(ax, 1:length(losses), losses; marker = :circle, markersize = 12, strokewidth = 2)

    fig
end
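
The plot shows the overall trend. A numerical summary can also be handy when comparing training runs. The following is a minimal sketch. It assumes `losses` is a plain vector of real-valued losses recorded once per iteration. The helper `summarize_losses` is hypothetical and not part of Lux or Optimization.jl.

```julia
# Hypothetical helper: summarize a loss history recorded once per iteration.
function summarize_losses(losses::AbstractVector{<:Real})
    isempty(losses) && throw(ArgumentError("loss history is empty"))
    best, best_iter = findmin(losses)       # lowest loss and the iteration where it occurred
    running_min = accumulate(min, losses)   # best loss seen up to each iteration
    reduction = 1 - best / first(losses)    # fractional reduction relative to the first loss
    return (; best, best_iter, reduction, running_min)
end

# Usage with a toy history (in the tutorial, pass `losses` instead)
loss_stats = summarize_losses([4.0, 2.0, 3.0, 1.0])
println("best loss ", loss_stats.best, " at iteration ", loss_stats.best_iter)
```

The running minimum can be plotted alongside `losses` with the same `lines!` call. This makes it easier to see when the optimizer stopped making progress.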

Finally, let us compare the waveform predicted by the trained network with the data and with the untrained network's prediction.

julia
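# Rebuild the ODE problem with the optimized parameters `res.u` and re-solve it with fixed-step RK4
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p = res.u, saveat = tsteps, dt, adaptive = false))
# Convert the ODE solution into a waveform
waveform_nn_trained = first(
    compute_waveform(
        dt_data, soln_nn, mass_ratio, ode_model_params
    )
)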

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "Time", ylabel = "Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth = 2, alpha = 0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker = :circle, alpha = 0.5, strokewidth = 2, markersize = 12
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth = 2, alpha = 0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker = :circle, alpha = 0.5, strokewidth = 2, markersize = 12
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth = 2, alpha = 0.75)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained; marker = :circle,
        alpha = 0.5, strokewidth = 2, markersize = 12
    )

    axislegend(
        ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position = :lb
    )

    fig
end
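
The `solve` call above uses `RK4()` with `adaptive = false`. This means the network-defined right-hand side is integrated with a constant step `dt` rather than an error-controlled step size. To make that concrete, here is a hand-rolled sketch of classical fixed-step fourth-order Runge-Kutta.

It is illustrative only and is not how OrdinaryDiffEq implements `RK4`. The function name `rk4_fixed` is hypothetical. The sketch assumes an out-of-place right-hand side `f(u, p, t)` that returns `du/dt`, and a `dt` that divides the time span evenly. It saves every step, whereas the tutorial's `saveat = tsteps` selects specific output times.

```julia
# Hypothetical sketch of classical fixed-step RK4 (not OrdinaryDiffEq's implementation).
# `f(u, p, t)` follows the out-of-place convention and returns du/dt.
function rk4_fixed(f, u0, p, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    isapprox(t0 + nsteps * dt, t1; rtol = 1e-10) ||
        throw(ArgumentError("dt must divide the time span evenly"))
    ts = [t0 + i * dt for i in 0:nsteps]
    us = Vector{typeof(u0)}(undef, nsteps + 1)
    us[1] = u0
    u = u0
    for i in 1:nsteps
        t = ts[i]
        # Four slope evaluations per step, combined with weights 1/6, 2/6, 2/6, 1/6
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        us[i + 1] = u
    end
    # Columns are time points, matching the layout of `Array(sol)`
    return ts, reduce(hcat, us)
end

# Usage: exponential decay du/dt = -p * u, whose exact solution at t = 1 is exp(-p)
decay(u, p, t) = -p .* u
ts, U = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.1)
```

Because the step size is fixed, the cost of each forward solve is predictable: four right-hand-side (network) evaluations per step. The trade-off is that accuracy depends entirely on the choice of `dt`.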

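The plot gives a qualitative comparison. To put a number on how much training improved the fit, one option is the mean squared error between the data and each prediction. The sketch below assumes the waveforms are equal-length real vectors sampled at the same `tsteps`. The helper `waveform_mse` is hypothetical, and the tutorial's own training loss may be defined differently.

```julia
# Hypothetical helper: mean squared error between two waveforms sampled at the same times.
function waveform_mse(reference::AbstractVector{<:Real}, prediction::AbstractVector{<:Real})
    length(reference) == length(prediction) ||
        throw(DimensionMismatch("waveforms must have the same length"))
    return sum(abs2, reference .- prediction) / length(reference)
end

# Toy check: a constant offset of 0.1 gives an MSE of about 0.01
t = 0:0.1:1
ref = sin.(t)
err_offset = waveform_mse(ref, ref .+ 0.1)
err_exact = waveform_mse(ref, ref)
```

With the variables from this tutorial, `waveform_mse(waveform, waveform_nn)` and `waveform_mse(waveform, waveform_nn_trained)` would give the error before and after training.
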
Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
  JULIA_CPU_THREADS = 128
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 128
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.