Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines the functions used to simulate the model; they are not especially relevant from a scientific machine learning perspective.
We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector r = r₁ - r₂ and use the Newtonian center-of-mass formulas to recover r₁ and r₂.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
one2two (generic function with 1 method)
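A quick sanity check on this split, as a minimal sketch: the two paths should differ by the original separation vector and keep the center of mass at the origin. The function is repeated so the snippet runs on its own; the sample path and masses are made-up values, not taken from the tutorial.

```julia
# Repeated from above so that this snippet is self-contained.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Hypothetical sample data: a 2×3 path (rows are x and y) and two masses.
path = [1.0 0.0 -1.0; 0.0 2.0 0.0]
m₁, m₂ = 0.25, 0.75

r₁, r₂ = one2two(path, m₁, m₂)

# The relative position r₁ - r₂ recovers the one-body path.
separation_ok = r₁ .- r₂ ≈ path
# The mass-weighted positions cancel, so the center of mass stays at the origin.
com_ok = all(abs.(m₁ .* r₁ .+ m₂ .* r₂) .< 1.0e-12)
```

Both flags should come out true for any path and any pair of positive masses.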
Next we define a function to perform the change of variables (χ(t), ϕ(t)) ↦ (x(t), y(t)):
@views function soln2orbit(soln, model_params = nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    if size(soln, 1) == 2
        # Only (χ, ϕ) are evolved; p and e are fixed and come from model_params.
        χ = soln[1, :]
        ϕ = soln[2, :]
        @assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        # (χ, ϕ, p, e) are all part of the solution.
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
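Written out, the change of variables that soln2orbit applies is the usual conic-section parametrization. Reading p as the semi-latus rectum and e as the eccentricity is an assumption based on that standard form; the code itself only names them p and e.

```latex
r = \frac{p}{1 + e\cos\chi}, \qquad
x = r\cos\phi, \qquad
y = r\sin\phi
```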
The function d_dt approximates the first time derivative of uniformly sampled data. It uses central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
d_dt (generic function with 1 method)
The function d2_dt2 approximates the second time derivative in the same way, with central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
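One way to convince yourself that the stencils are second-order accurate is to apply them to a quadratic, for which they should be exact up to round-off, endpoints included. The sketch below repeats both functions so that it runs on its own; the test signal and grid are hypothetical choices, not part of the tutorial.

```julia
# Repeated from above so that this snippet is self-contained.
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end

function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

# Hypothetical test signal: v(t) = 3t² - 2t + 1 on a uniform grid.
dt = 0.1
t = collect(0.0:dt:1.0)
v = @. 3 * t^2 - 2 * t + 1

# Exact derivatives are v'(t) = 6t - 2 and v''(t) = 6.
first_ok = maximum(abs.(d_dt(v, dt) .- (6 .* t .- 2))) < 1.0e-9
second_ok = maximum(abs.(d2_dt2(v, dt) .- 6)) < 1.0e-9
```

Both outputs have the same length as the input, because the one-sided stencils fill in the first and last samples.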
Now we define a function to compute the trace-free moment tensor from the orbit
function orbit2tensor(orbit, component, mass = 1)
    x = orbit[1, :]
    y = orbit[2, :]
    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy
    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end
    return mass .* tmp
end
function h_22_quadrupole_components(dt, orbit, component, mass = 1)
    mtensor = orbit2tensor(orbit, component, mass)
    mtensor_ddot = d2_dt2(mtensor, dt)
    return 2 * mtensor_ddot
end
function h_22_quadrupole(dt, orbit, mass = 1)
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    return h11, h12, h22
end
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
    h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = h11_1 + h11_2
    h12 = h12_1 + h12_2
    h22 = h22_1 + h22_2
    return h11, h12, h22
end
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # Compute the (2,2)-mode strain from the orbits of BH 1 (mass1) and BH 2 (mass2)
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h₊ = h11 - h22
    hₓ = T(2) * h12
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end
function compute_waveform(dt::T, soln, mass_ratio, model_params = nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂
        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)
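Read off from the code above, the helper functions compute the quantities below. The symbols \mathcal{I}_{ij} (trace-free moment tensor) and q (mass_ratio) are notation introduced here for readability and do not appear in the source.

```latex
\mathcal{I}_{xx} = m\left(x^2 - \tfrac{1}{3}\,(x^2 + y^2)\right), \qquad
\mathcal{I}_{yy} = m\left(y^2 - \tfrac{1}{3}\,(x^2 + y^2)\right), \qquad
\mathcal{I}_{xy} = m\,x\,y

h_{ij} = 2\,\ddot{\mathcal{I}}_{ij}, \qquad
h_{+} = h_{xx} - h_{yy}, \qquad
h_{\times} = 2\,h_{xy}

\text{waveform} = \left(\sqrt{\tfrac{\pi}{5}}\;h_{+},\; -\sqrt{\tfrac{\pi}{5}}\;h_{\times}\right)

m_2 = \frac{1}{1 + q}, \qquad m_1 = \frac{q}{1 + q} \qquad (q > 0)
```

In the two-body case the tensors of the two orbits are summed before forming h₊ and hₓ; for q = 0 the one-body path is used directly with unit mass.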
Simulating the True Model
RelativisticOrbitModel
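defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}},$$

where p, M, and e are constants.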
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions (χ, ϕ)
datasize = 250
tspan = (0.0f0, 6.0f4) # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length = datasize) # time at each saved sample
dt_data = tsteps[2] - tsteps[1] # spacing between saved samples
dt = 100.0 # fixed step size for the RK4 solver
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat = tsteps, dt, adaptive = false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "Time", ylabel = "Waveform")
    l = lines!(ax, tsteps, waveform; linewidth = 2, alpha = 0.75)
    s = scatter!(ax, tsteps, waveform; marker = :circle, markersize = 12, alpha = 0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])
    fig
end
Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (χ, the first component of the state) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
It is generally not recommended to use globals, but in case you do use them, make sure to mark them as const.
We will deviate from the standard neural network initialization and use WeightInitializers.jl to draw the weights from a truncated normal distribution with a small standard deviation (1.0e-4) and to set the biases to zero.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight = truncated_normal(; std = 1.0e-4), init_bias = zeros32),
    Dense(32 => 32, cos; init_weight = truncated_normal(; std = 1.0e-4), init_bias = zeros32),
    Dense(32 => 2; init_weight = truncated_normal(; std = 1.0e-4), init_bias = zeros32)
)
ps, st = Lux.setup(Random.default_rng(), nn)
The output is the tuple (ps, st). ps holds the Float32 weight and bias arrays for layer_2, layer_3, and layer_4 (layer_1 has no parameters): the weights are of order 1.0e-4 and all biases are zero. st holds an empty NamedTuple for each of the four layers.
Similar to most deep learning frameworks, Lux defaults to Float32. In this case, however, we need Float64, so we convert the parameters:
const params = ComponentArray(ps |> f64)
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
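As a rough illustration of the Float32 to Float64 conversion above, the following sketch compares the two types using only Base Julia. It uses broadcasting with `Float64.(...)` as a stand-in for Lux's `f64`, and the array `w32` is a made-up example, not the tutorial's parameters. The tutorial does not state why Float64 is required, so this only shows how large the precision gap is.

```julia
# Made-up Float32 "weights" (an assumption for illustration only).
w32 = Float32[0.1, 0.2, 0.3]

# Widen to Float64 with plain broadcasting (stand-in for `ps |> f64`).
w64 = Float64.(w32)

println("eltype before: ", eltype(w32), ", after: ", eltype(w64))
println("eps(Float32) = ", eps(Float32))
println("eps(Float64) = ", eps(Float64))

# Widening keeps the stored Float32 value exactly; it does not recover
# the digits that were lost when 0.1 was first rounded to Float32.
println("w64[1] == 0.1 ? ", w64[1] == 0.1)
```

The conversion changes the element type used in all later arithmetic, while the initial values stay exactly what they were in Float32.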
Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics. The state is u = (χ, ϕ), that is, u[1] = χ and u[2] = ϕ,
where p, M, and e are constants.
function ODE_model(u, nn_params, t)
χ, ϕ = u
p, M, e = ode_model_params
# In this example we know that `st` is an empty NamedTuple, so we can safely ignore
# it, however, in general, we should use `st` to store the state of the neural network.
y = 1 .+ nn_model([first(u)], nn_params)
numer = (1 + e * cos(χ))^2
denom = M * (p^(3 / 2))
χ̇ = (numer / denom) * y[1]
ϕ̇ = (numer / denom) * y[2]
return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
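To make the structure of this right-hand side easier to see, here is a minimal sketch in plain Julia. It replaces the network with a hypothetical stub `nn_stub` that returns zeros, so `y = 1` and the equations reduce to the purely Newtonian case. The names `rhs_sketch` and `rk4_fixed`, and the values `p = 100`, `M = 1`, `e = 0.5`, are assumptions for illustration. The hand-written integrator is a textbook fixed-step RK4, not the `RK4()` solver from OrdinaryDiffEq.

```julia
# Hypothetical stand-in for the neural network: a zero correction, so y = 1.
nn_stub(x, θ) = zeros(2)

# Same structure as ODE_model, with the constants passed as keywords.
function rhs_sketch(u, θ, t; p = 100.0, M = 1.0, e = 0.5, nn = nn_stub)
    χ, ϕ = u
    y = 1 .+ nn([χ], θ)
    numer = (1 + e * cos(χ))^2
    denom = M * p^(3 / 2)
    return [(numer / denom) * y[1], (numer / denom) * y[2]]
end

# Textbook fixed-step fourth-order Runge-Kutta; returns the state at every step.
function rk4_fixed(f, u0, θ, tspan, dt)
    t, u = tspan[1], copy(u0)
    nsteps = round(Int, (tspan[2] - tspan[1]) / dt)
    traj = [copy(u)]
    for _ in 1:nsteps
        k1 = f(u, θ, t)
        k2 = f(u .+ (dt / 2) .* k1, θ, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, θ, t + dt / 2)
        k4 = f(u .+ dt .* k3, θ, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
        push!(traj, copy(u))
    end
    return traj
end

# Eccentric orbit with the assumed constants.
traj = rk4_fixed(rhs_sketch, [π, 0.0], nothing, (0.0, 1000.0), 10.0)

# Circular orbit (e = 0): the right-hand side is the constant 1 / (M * p^(3/2)).
circular(u, θ, t) = rhs_sketch(u, θ, t; e = 0.0)
traj_circ = rk4_fixed(circular, [π, 0.0], nothing, (0.0, 1000.0), 10.0)
println("χ advanced by ", traj_circ[end][1] - π, " in the circular case")
```

With `y = 1` both components have the same derivative, so χ and ϕ advance by the same amount. The trained network's job is to learn the corrections that break this symmetry.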
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p = params, saveat = tsteps, dt, adaptive = false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel = "Time", ylabel = "Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth = 2, alpha = 0.75)
s1 = scatter!(
ax, tsteps, waveform; marker = :circle, markersize = 12, alpha = 0.5, strokewidth = 2
)
l2 = lines!(ax, tsteps, waveform_nn; linewidth = 2, alpha = 0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker = :circle, markersize = 12, alpha = 0.5, strokewidth = 2
)
axislegend(
ax, [[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"]; position = :lb
)
fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
function loss(θ)
pred = Array(solve(prob_nn, RK4(); u0, p = θ, saveat = tsteps, dt, adaptive = false))
pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
Warm up the loss function
loss(params)
0.0006864946400429274
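For reference, the quantity being minimized is a mean squared error between the predicted and the target waveform. The sketch below computes it in plain Julia, assuming the default mean aggregation of `MSELoss()`. The function name `mse_sketch` and the sample vectors are made up for illustration.

```julia
using Statistics  # standard library, provides `mean`

# Mean of the squared element-wise differences (hypothetical helper).
mse_sketch(pred, target) = mean(abs2, pred .- target)

# Made-up sample data, not the tutorial's waveforms.
pred_example = [1.0, 2.0, 3.0]
target_example = [1.0, 2.5, 2.0]

# Squared differences are 0.0, 0.25 and 1.0, so the mean is 1.25 / 3.
println(mse_sketch(pred_example, target_example))
```

Because the loss compares waveforms rather than the ODE state, the gradient has to flow through both `compute_waveform` and the ODE solve.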
Now let us define a callback function to store the loss over time
const losses = Float64[]
function callback(θ, l)
push!(losses, l)
@printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
return false
end
callback (generic function with 1 method)
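The callback above always returns `false`, which tells the optimizer to keep going. In Optimization.jl, returning `true` from the callback halts the optimization, so the same hook can implement early stopping. The sketch below shows that pattern in plain Julia. The factory `make_callback`, the tolerance values, and the NamedTuple used as a stand-in for the optimizer state are all assumptions for illustration.

```julia
# Build a callback that records losses and requests a stop once the
# loss drops below `tol`. Returns the callback and its history vector.
function make_callback(; tol = 1.0e-6)
    history = Float64[]
    callback = (state, l) -> begin
        push!(history, l)
        return l < tol  # true means "halt the optimization"
    end
    return callback, history
end

cb, history = make_callback(; tol = 1.0e-3)

# NamedTuples stand in for the optimizer state, which carries an `iter` field.
keep_going = cb((iter = 1,), 0.5)
stop_now = cb((iter = 2,), 1.0e-4)
println("first call returned ", keep_going, ", second call returned ", stop_now)
println("recorded losses: ", history)
```

Keeping the history inside a closure avoids a global vector, which can be convenient when several training runs are compared in one session.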
Training the Neural Network
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
optprob, BFGS(; initial_stepnorm = 0.01, linesearch = LineSearches.BackTracking());
callback, maxiters = 1000
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [1.1684905075510235e-5; -4.756234193336002e-5; -9.834645607041006e-5; 9.27074142963352e-5; -5.583238453256071e-5; 5.513598080143146e-5; -5.3152118198318194e-5; 0.00016854524437804217; -0.00019818036525975926; 4.023786459582534e-5; 6.343576387734276e-6; 6.203808879937171e-5; 7.028735126360051e-5; -8.307965617857265e-5; 0.00011663161421886803; 2.353656782358725e-6; -7.852615817684807e-5; -1.4547541468323432e-5; 0.00014218158321448301; -8.962746505859295e-5; -5.7716213632383133e-5; -4.23954734287463e-5; 5.547955879589149e-5; 3.7869260268019525e-5; 8.198458817778692e-5; -2.160575422745688e-5; -0.00019508181139819516; 9.92121385933282e-7; -0.00014845223631679104; 0.00025220695533806657; 0.00012981895997633323; 0.00010813903645594471;;], bias = [1.865859080332062e-17, -1.2824401977179955e-17, -2.1914020048586422e-17, 1.3953906849459077e-16, 1.172737118145792e-17, -7.63723918364325e-18, -2.6090242252946762e-17, -7.754440293700479e-19, 6.5590527410027875e-18, 1.447365033888795e-17, 9.213948889934352e-19, 9.5054922366098e-17, 1.0030537726226057e-16, -2.855907520598836e-17, 7.342817000470606e-17, 2.8958011037053682e-18, 2.2448218901465697e-17, -2.5934926002178136e-17, 7.769523381962897e-17, 1.0152856126791147e-17, 8.858032523348367e-18, -2.5924201151734986e-17, 3.6167871105813494e-17, 3.858662067995849e-17, 1.0144088290169518e-16, -4.250749332739879e-17, -7.996053400570479e-17, 1.1704354945032432e-18, -3.529262225762163e-16, -1.9489141108862974e-16, 2.4562664924127276e-16, 2.0552684131197199e-16]), layer_3 = (weight = [2.0886512883486433e-5 -1.0384860816805394e-5 -9.92340958154373e-5 -0.00023337696095123765 -6.719626455461609e-5 0.00011228646124353629 -2.775116647254686e-5 -0.0001373581710330711 -9.93750602182476e-5 -1.628159972907548e-5 6.670039672743454e-5 -4.022519734382961e-6 -8.666834239276088e-5 2.8168140287642277e-5 7.492126840505988e-5 5.650303845947317e-6 4.486371633819371e-5 5.2468136144931645e-5 
-7.762909283056463e-5 -2.2530695988199503e-5 -6.229308573098254e-7 4.899205831940874e-6 -1.5006419945086865e-5 0.0001162461828963268 -0.00012564019763968806 -0.00023816845552656864 0.00011265720038781751 0.00010839276163013575 4.479745419220503e-5 5.289979688232301e-5 0.00013163737255953567 9.073361232284652e-5; -5.991680779548303e-5 6.0145624165349133e-5 5.025701205308536e-5 6.925079026554707e-5 -4.2840771921285996e-5 -0.00011384374272015841 7.883465253111944e-5 0.0001348178713029887 0.00010448932255680441 0.00011217414533316205 -2.8544708466464066e-5 -0.0001297731858997882 9.199162844165251e-5 -1.2790864157390234e-5 -1.138069535077591e-5 6.376366321059935e-5 -0.00012432083064627816 5.937419803526311e-5 -5.765940268570297e-6 -6.628699963617137e-5 -0.00016994991789975167 -1.5245925961019946e-5 5.888010958155654e-5 2.233459156709306e-5 -2.5446962236268356e-5 -0.00012901169872780275 -1.0667107626756084e-5 -5.9644024868569674e-5 -8.53901282156176e-5 0.00018054211353013812 -2.7472945357980863e-5 9.249442621662254e-5; 0.0001503598211817801 0.0001356452300930741 4.519020125342726e-5 2.0254415848905024e-5 -7.291068195913513e-6 3.60687370081772e-5 2.0442668519249776e-5 -6.865478335686872e-6 -7.816060338913245e-5 0.00010038438971576973 7.814714239401601e-5 5.624461458836304e-5 2.5180171851101535e-5 -8.598907934644801e-6 0.00010461683608780564 8.417887092662052e-6 -0.0001445752243492432 5.4431009393453945e-5 6.919634664852642e-5 -0.00020710841296051334 -0.00021790916458513932 8.212513397637506e-5 8.130416585089308e-5 4.073712965663295e-5 -4.415628846179194e-6 -2.1176984101032647e-5 -5.716784313366453e-5 -9.65653127971678e-5 0.00014454641835952533 -0.00018500140528008698 0.00019190006763863127 0.0001110838818382758; -6.108892722390552e-5 -0.00010922487291377052 0.0002144846423006596 7.722397111952125e-5 -4.049471846069246e-5 -0.00014348791756769677 0.00012652662388877917 3.794887025129504e-8 -2.0887558831470487e-5 1.8649420225789308e-6 -7.487293060971381e-5 
-1.2505532878137278e-5 1.7931248466041038e-5 4.052000066578443e-6 -6.35627164329433e-5 2.195034920892257e-5 -5.9501116821636666e-5 -2.8600184860848428e-5 0.00021861523254205922 -9.473946021033222e-6 -3.3900867320923554e-5 -3.7145125871617367e-5 -7.525988058755209e-5 0.0001179043376295828 -7.879682544721418e-5 8.108308810638296e-5 0.00010687970478915602 -1.4864807609270351e-5 -4.841013816171851e-5 0.00016364023594081663 -5.712422517813203e-5 3.23262704013821e-5; 4.649388761815362e-5 -7.074089984901344e-5 -2.900254406378375e-6 -0.0002108788630205913 2.350482469107461e-5 0.00010442061289420317 3.572175107117096e-7 3.2614449134320165e-5 -8.640345362736941e-5 -3.8098762009056087e-7 0.00011035078756019212 -7.69233392414448e-5 3.482900324950103e-5 -0.00013451565059160555 0.00010987181854641827 2.6516229869870457e-6 -1.8591033898087068e-5 -4.1373210427381715e-5 2.570288240064776e-6 -6.161641971932383e-5 0.00010754521468395354 -6.405911738809781e-6 2.69697757691328e-5 -7.406224353265174e-5 -6.598113572483639e-5 -0.00012123931987237414 -0.0001190668644479029 -2.7029769574555406e-5 -0.00016030452565547548 4.386883626214593e-6 2.5631181467073785e-5 2.1791420346453444e-5; 4.492085381021722e-5 4.681998788688125e-5 1.1738988263413851e-5 -5.809304539620764e-5 -0.0001278407025038573 9.50791773847928e-5 0.00015312687637011258 7.645887407664268e-5 -1.469520876169732e-5 2.1222994377201837e-5 -7.457674192304717e-5 -2.835765272361121e-5 0.00010177688554354477 -2.5318528883915437e-5 -0.0001796078458083272 8.017907848075974e-5 -0.00016124718504765656 0.0001866283829203527 0.00015553926377282398 -9.439069714106821e-6 6.759022389265396e-5 2.5062793696595594e-5 1.4377594283320411e-5 0.00010774798019516155 2.2931654797460087e-5 9.338569097413647e-5 8.111955421001706e-5 1.6179264179263384e-5 3.341745204482902e-5 4.6635389566192466e-5 -6.973446116358728e-5 9.25998802758362e-5; -1.947638547447107e-5 -0.0001073025944572053 0.00012286410283537943 -5.6912902706456846e-5 0.00010645224911401756 
0.00020565370296369087 -1.302530135601965e-5 2.5446713792693948e-5 8.461595798848845e-5 -0.00016344983514157268 -5.857261775984487e-5 -8.961904886403383e-5 0.00010963173345205054 -1.737554085805179e-6 -6.905648687846808e-5 -3.482583478783314e-5 0.0001386207565735894 -1.085391185936585e-5 6.383733644157448e-5 -0.0001058577275182262 -3.471289009778993e-5 0.00014418891596848924 -5.8280455319827744e-5 4.3083846052259746e-5 5.999103635774749e-6 -2.0595273141433355e-5 -1.062650544213351e-5 -8.515586732412744e-5 8.221745311289728e-5 6.019300936144128e-5 0.00017254352651776726 1.4209661188728845e-5; -9.793854427700486e-5 9.850900802119212e-5 -0.00013970247549879327 1.8983167910703937e-6 -0.00011692821884937024 -3.332374421008306e-5 -9.675670320577376e-5 -0.00010422886429296533 3.0020171616347666e-5 -0.0001510845179291395 1.0252138745252955e-5 8.694751137225564e-5 1.057488567310595e-5 -1.3146738856464103e-5 -1.1407176534714614e-5 0.00010160758371055494 6.872133941213422e-5 -4.686092892372599e-5 -4.5148630524643275e-5 -2.6895604823963424e-5 -0.00010194852096903494 -5.178052400880945e-5 9.962236777935736e-5 6.436260424377854e-5 3.189132963596698e-5 -5.504435124750783e-5 -4.3338717878210027e-5 1.902993950350578e-5 6.134219416707631e-5 5.010004293582999e-5 -9.705215074064655e-5 -2.9735669377272245e-5; -9.650537241340354e-5 4.112797695245397e-6 -2.8278733328777014e-6 -2.431171115984239e-5 9.387905164532085e-5 3.574471419471026e-5 0.00013382391342803515 0.0001250156391403382 -8.642603926212739e-5 4.243104275575363e-5 -1.4071650742976495e-6 -0.00016365546172108684 -5.938529203502503e-5 3.810248825959145e-5 -0.00022336227549328311 -4.7622549674443324e-5 -2.720022813389891e-5 -5.968394099126298e-5 5.070071064313965e-5 5.001967373447855e-5 -2.4605046848033288e-5 -8.268117663773229e-5 -0.00016701916602995532 2.0682723414767193e-5 6.074094279812839e-5 3.6909695050872585e-6 -2.68332706685541e-5 -0.000201565398230187 -0.00030435894869820817 5.919307376525539e-5 -0.00011349754923809667 
-6.613780270652005e-5; 6.894806892596058e-5 6.588190035181052e-5 -5.356548399795638e-5 6.813991146188868e-6 -0.0001754253648898196 -0.00012160885353587615 -9.635873402949475e-5 9.650051194333633e-6 4.910954839652447e-5 -4.9895698342156525e-5 5.28955387998824e-6 0.00010840565255129686 -3.879746191451318e-5 0.00015994511659009442 -5.4892265782838035e-5 -0.00016180651846589653 -7.48347852124456e-5 -1.653216031918893e-5 -6.866899321103031e-5 0.00015062660849871203 0.0002472924347755868 6.797457490101375e-5 0.00010658429846152617 -0.00012163789915867242 -6.801496465705278e-5 -1.8546716549467546e-5 -0.00012419606037219046 1.4764196946950634e-6 6.784557944847759e-5 -0.00012581839519333295 2.6951830929465293e-5 -6.37297444263974e-6; -0.0001365522551890667 3.752281156319935e-5 0.0001878226099693941 4.687816836472171e-6 -0.00011081425362082391 2.0609691655941256e-5 -0.0001777137152925339 -0.00015084716904258615 1.973011205389809e-5 9.00768278557602e-5 -5.742050814607364e-5 1.0713333377303018e-5 0.0001615104771377313 -0.00010087448061259666 -8.372974878458525e-5 4.6913147908149314e-5 -1.377030855966466e-5 7.275336212895626e-5 1.3486373628061228e-5 7.704395067829435e-5 -8.012138312498299e-5 0.00013842942245954464 -0.0001318853832157534 -2.82337390484959e-5 4.3390489391263425e-5 -3.769428110846503e-5 0.00015222909098209532 9.14778353222441e-5 -5.630154953852537e-5 3.9359685270559265e-5 -8.705833932200674e-5 6.76003907033966e-5; 6.163793469353238e-5 -2.882821908530274e-5 -1.823797000940209e-5 -3.028204090912104e-5 -9.3756940343487e-5 -9.5532004790916e-5 3.4474325646110984e-5 0.00016388879331100207 5.1014916523508566e-5 1.2704832358973281e-5 4.578007966086862e-5 0.00011647980309829557 -9.6428424596857e-5 2.7405081625723717e-5 -0.00010628308345291703 -2.8459731849910345e-5 0.0001699479343257135 6.21801463308989e-5 7.48859110059383e-5 -6.6654407349841256e-6 -9.006770241500344e-5 0.00012575961764708682 1.3626136849448162e-5 6.269043106625395e-5 8.060028076909814e-5 
3.700148036626745e-5 1.9224739004313867e-6 -0.00016597113980639132 7.05973233449432e-5 7.630718256306523e-7 -5.3023255616035985e-5 -6.913937076421318e-5; -0.00017198994011053003 8.337034148985306e-6 -3.815789297006924e-5 -0.0001533822605888631 0.00012462983885413916 0.000205340770651651 9.806430160074406e-6 -9.40057461885031e-5 9.239228255681534e-5 4.1281987239903874e-5 7.536222709675202e-5 7.475320034061525e-5 -9.117408172826292e-5 -0.00014074677438841166 -0.00013111986755759265 -0.00016392932867531783 -0.00012497073749911952 0.00015171972789901193 -0.00010060925252402793 0.000126607167647286 7.676034600806378e-5 -9.605349169944492e-5 0.00011591984820186111 0.00011354585602746258 2.5578400698555968e-5 6.363610463083614e-5 3.7376089483572166e-5 -0.0001365918678966642 -0.0001245022676921691 -2.5330825348655488e-5 3.834969628585766e-5 2.0366004708173285e-5; -0.00011352584235774369 -0.0002940098636148853 -0.0001418764875101219 1.8243547196656023e-5 6.512996855779813e-6 -1.6689473091356955e-5 -2.612628097758883e-5 2.614083157334066e-5 -9.303419913100498e-5 -5.215953181086932e-5 2.1187152264788494e-5 -6.927706247133837e-5 2.360306304999552e-5 7.439097705592026e-5 2.337819230397893e-5 3.240615139893227e-5 -0.00015396410865414393 -0.00013277998343747119 0.0002172155803005939 6.0796741558627227e-5 2.864701698363081e-5 0.00019429887495298062 -8.612358190057569e-5 8.098145942730921e-5 -8.466803385581282e-5 -7.629740841858078e-5 -3.691959475451672e-5 -0.00018951811919160218 0.00010151827907485867 1.2802664965919651e-6 -2.318237941416441e-5 -1.911677984516569e-5; -9.90671865105447e-5 0.000167567228800729 0.00010608356069551972 -0.00013052917665790605 2.666418231511051e-5 0.00010259347300289226 1.8157976683690418e-5 -8.787288561451528e-5 -5.1715428980412074e-5 5.688201463326802e-5 -0.00012741393174964984 -5.509288121120026e-5 -6.980637726828384e-5 1.6115422568513368e-5 -8.277318147465299e-5 -1.6466757642226783e-5 0.00019177553304016692 -0.00011531533118559089 
6.17482332932994e-5 2.269362994961305e-6 -1.9137841717926455e-5 -1.3781926554108548e-5 -9.450592642130524e-6 9.933112278358999e-5 3.363808764388901e-5 -0.00027377181277133566 -2.1997329440080033e-5 -6.243743291579234e-5 7.43549208037468e-6 0.0001000117958943698 1.3472243609255124e-5 2.564440592378152e-5; -9.716315684228733e-5 -9.369601751997697e-5 -2.048959222083595e-6 -1.075668818841186e-5 0.00012251715385035034 -0.00022005083586874993 -0.0001222987673277372 -3.517768413032444e-5 5.0234323917243404e-5 -5.631300586130092e-5 -3.709966471389464e-5 6.337101809593262e-5 0.00017728950048011123 -0.0002194761661844671 -2.725946864258772e-5 8.609894144727321e-5 8.553205703763599e-5 -0.00017367523536406058 -0.00010516272299311966 6.85526495185336e-5 3.121271152893432e-5 4.540472000809649e-6 0.00017897822113853304 -2.9168926586714285e-7 2.4275167798333063e-5 -6.499044265654632e-6 0.0002657036632267367 0.00017111294006143168 2.4510881540189168e-5 -3.9071449227348265e-5 7.124858494061291e-6 1.59373896717442e-5; 9.474625907973601e-5 4.6134174723597314e-5 -0.00023721686962159132 4.783368926513329e-5 9.181839190791284e-5 3.033065472340608e-5 -5.032948543299037e-5 -4.219142506697896e-5 -0.00011894097629337446 9.382366037829849e-5 6.245944454935603e-5 5.049733285369926e-7 2.5731347292496144e-5 2.079466462095288e-5 7.979854630905967e-5 0.00020075181741635538 0.00012970171296405232 -5.707727226356771e-6 -1.1678750027908218e-5 -3.7865660900535034e-5 -2.1957076072391167e-5 7.475949455187708e-5 -2.0022457150440457e-5 5.815317267337901e-5 2.0471151898564213e-5 0.00011795174718504022 -0.00013098278073229587 -4.8749154715144596e-5 0.00011614023746590071 -2.2572127851352744e-6 0.00020199094210186487 8.719528110700381e-6; -2.6468726825355355e-7 0.0001228136181244584 7.151638526656065e-5 2.226558097126183e-5 6.343214152101092e-5 0.00017086931000183923 -0.0001961727746883444 -2.7476166745043323e-5 -6.767106135419017e-5 -6.887316596741347e-5 -1.7107175902237174e-5 -8.900802533654758e-5 
-9.452735705884758e-6 -0.00015237322697695294 -0.00012206368303205087 -5.165424119961171e-5 -5.2319929475663494e-5 -0.00022294517004676795 -1.750144367442459e-5 -6.683567955670714e-5 4.79481943121614e-5 7.617387125454371e-5 0.00010712375040119173 -3.991165687959219e-5 -9.66825744607483e-5 -1.870677335584212e-5 1.779220763375945e-5 -3.865376749938078e-5 -5.8588196064096804e-5 4.082164484966783e-5 -4.345968666066285e-5 -0.0001479055125154501; 0.00015671655901217373 0.00011593683202493258 3.5736683365865256e-5 9.387878798333276e-6 0.00015109073037722316 2.128367022935476e-5 -3.813026442101028e-5 -4.422658557914746e-5 2.6932361742078514e-5 -0.00013637589390880111 3.545257177295644e-5 0.00010652278333411642 0.00011007369803239572 -7.754196221965582e-5 6.514036694981189e-5 5.7345538970302044e-5 -4.363574144106502e-5 1.385140776522223e-5 4.145785614937829e-5 7.9254902538123e-5 7.660339078840946e-5 -0.00019993183477806558 -6.989015773895023e-5 0.00010664261835602452 6.95958432506098e-5 0.00012430570395638674 7.067000742507494e-5 -1.7070348895750748e-5 -9.844369296942737e-5 -0.00010253255559395481 4.33896720906883e-6 0.00019166986648898795; -0.0001732285721398425 7.29810732711587e-5 -2.1029812460967804e-5 -0.00017606094236797035 -5.676609208228981e-6 2.837595881757066e-5 0.0001642690501004942 -0.00026247714170564673 0.0003331862123279656 6.55768769040248e-5 -6.131907624038887e-5 -1.284053844827893e-5 -0.0001258618697927764 5.624235442993724e-5 -9.533731402441801e-5 -0.00016105785955986958 -1.0386210606871449e-5 8.131686869929539e-5 0.00013746466963241428 -3.454637529174221e-5 -2.0818167586901746e-5 5.0593431905518486e-5 1.8040477397527763e-5 -2.4387488638958034e-5 -2.6894647208596443e-5 -5.547947269477867e-5 5.821388428488763e-5 -0.0001099828635513319 -9.905359668734725e-5 -0.00012038464531609734 3.066488870193239e-6 -0.0001111492577645377; -7.036141136819351e-5 8.678300405694186e-5 2.552918890972643e-5 4.01532507148068e-5 9.326366312410643e-5 3.512890002997558e-5 
2.611213780610151e-7 -0.00014775398262470703 0.00013994587975859227 -5.251815122943316e-5 0.00018048358656773498 -2.165348082240755e-5 2.4809116488484087e-5 -3.2106430759143945e-5 1.099609112673683e-5 -0.00016451464575407873 -4.406073089381907e-5 2.1542526204885473e-5 -3.1395566062267664e-5 0.00012607436843057752 -6.941066198675014e-5 -4.029125006678471e-5 6.903801300171726e-5 0.00022466023261034928 -2.60377800810936e-6 0.00013418060012003536 7.805045793603767e-5 5.514945410499504e-5 2.8813199641498624e-6 -3.166135679394064e-5 -3.734989326955707e-5 -9.953844523449067e-5; 1.0228152439215116e-5 -6.448375958135959e-5 -0.00014066502213828686 8.016320911118313e-5 7.573170709072958e-5 2.9488721827495535e-5 -7.480797236204741e-5 -5.6986427318779224e-5 -2.934020582179678e-5 -0.00013606687886059913 -0.0001473992556042663 -5.4692903597621346e-5 7.607608544056694e-5 -3.195421910496919e-5 -9.931092378102564e-5 -6.403843459558407e-5 0.00016295265644438052 6.842460110199908e-5 5.695573830743166e-5 0.00017460047505445312 0.00016460127202443196 -4.602367846398066e-6 -0.0002453012379759609 -1.9600989128245753e-5 -0.00011705477978686367 -0.00012754876976252256 -8.831132751075256e-5 -3.5285348951599704e-5 0.0002007783810200531 0.0001487777958915104 5.339742759823559e-5 1.903321075368302e-5; -8.148181004690958e-5 0.00015064890831234973 5.387938508756077e-5 -4.0390715729524196e-5 4.75992332223446e-5 -0.00010612150581328977 -6.196399553920002e-5 -2.9854437888201392e-5 -7.457188403239949e-5 5.577172705778919e-5 -0.00010901028654080843 0.00014127393878988193 9.336059374726643e-5 -4.6251765160202866e-5 -7.029741900514958e-5 5.82736669468318e-5 -9.643258271179031e-6 -1.9635117131485776e-5 -8.890172975843522e-5 4.793375320507765e-6 -3.9254025613272734e-5 -9.871280562839392e-5 3.6268130781441328e-6 -6.656187688241657e-5 -3.614001761780679e-5 -4.5170754309594654e-5 6.869600692990619e-5 3.3190843533873264e-5 -7.337172209990786e-5 -9.218902380513637e-7 3.5457893683268226e-5 
-0.0001419551532297235; 0.00011975422010316558 -0.00017214570336484262 -0.00027085410922274696 -3.839253641964402e-6 -4.483255044881151e-5 0.0001514803202359668 6.205753514102758e-5 7.260344635022465e-5 1.866324007886331e-5 -0.0001185755995914612 -4.581451732640109e-5 0.00012582049155633083 -0.00013230139531635542 0.0001043868810290209 9.779209129072101e-5 -6.740465797814099e-6 5.1547250658293946e-5 -1.4675872554401387e-5 1.0711160465399413e-5 9.813630956948933e-5 3.098933800963792e-5 -6.138897380999212e-5 -0.00014040806356326572 -4.1671186883216e-5 -3.373722076623527e-5 -1.600940097098186e-5 -4.6028881589622844e-5 2.5249511489780125e-5 0.0002601685896285577 -5.077108882806857e-5 -0.0001862846352390256 -2.384974569686673e-6; -8.73342923968119e-5 -2.5005097990180588e-5 4.3061414549924776e-5 5.922609039884205e-6 0.00012902126946402764 1.772225187425544e-5 0.00016536547810221111 -1.669818447577017e-5 7.055268543458798e-5 6.321326691806242e-5 -6.415398638081013e-5 -0.0002502135303019535 4.249144877223498e-5 -0.0001404141785259322 1.80791375952399e-5 -0.00017742515591988547 0.00015957269534385134 -0.00016143012725832132 -8.595098005922906e-5 -7.648530844915201e-5 -2.6302486725290568e-5 -0.00010412783557865828 -1.3333260534528653e-5 0.00011255155732771294 1.510625498314969e-5 -0.00011306515625063502 4.9764445071696925e-5 -1.0306877806109606e-5 -0.0001264801958258102 -9.301824314958028e-6 -3.978285712389018e-5 -5.75972790046519e-5; -0.0001710065380486901 1.2107378997593525e-5 6.845939332355614e-5 3.201702110322119e-5 -1.4386551820512515e-5 2.9093985167259732e-5 -2.5906233541041116e-5 -0.00013805934305334672 2.8871679196275348e-5 9.141816825808327e-5 -0.00011163550244441803 -3.301525885968288e-5 -8.27496232980198e-5 0.00016561978826734253 -0.00021769896285173823 6.787797155061191e-5 -9.446246714129055e-5 4.8598371895192225e-5 0.00017088391449736665 6.166435467979828e-5 -0.00014093954542682298 0.00010312678476997943 4.473833077776842e-6 1.113405868613472e-5 
8.65730134953468e-5 2.6040429465634653e-5 2.4980877232996047e-5 6.573537665621764e-5 0.0002254934542995519 2.1777784231479783e-5 0.00017864984964920623 -4.9324579786353345e-5; 4.567550580013707e-5 -1.6482699964578773e-5 0.00013995414495348871 -0.0002758874382090583 -3.628753826723981e-5 -0.00012928111233868252 -0.00017629848140877303 -8.023814577261512e-5 7.19638280605367e-5 -6.477160629237727e-5 0.00012292546464941619 -0.00019989636190820975 -0.00022472419938657844 1.6715621069858e-5 -2.5026211249667665e-5 -4.451674091633255e-5 -5.4645921653167735e-5 -3.394552756453695e-5 -1.7251222987563525e-5 6.296229706259262e-5 -6.226101528972916e-5 -1.018077661519583e-5 -5.769692713165274e-5 7.684838213372304e-5 -8.438809733724499e-5 0.00011126585171205178 6.277896475861043e-5 -3.096812203530069e-5 -1.1174393210404765e-5 0.0001509882656426752 -2.1253325274470148e-5 0.00016191062762282884; -4.705865676252193e-5 0.00010958289904200603 -9.91450024039752e-6 6.298227499618782e-5 5.454399406707108e-5 -6.805677576076983e-5 -2.2115566916094667e-6 6.994507539467101e-5 -2.8019909493997723e-5 -6.124263045231726e-5 5.924624174630241e-5 6.514009057437196e-6 -3.44926668712185e-5 6.575422567229253e-5 5.9831370619665976e-5 0.0001567007781905771 6.562149765349407e-5 3.6296205236874356e-5 2.589543512024928e-6 -6.209839220710483e-5 -0.0001746302080372607 -0.00011206176741135209 -0.00014193128326761347 -0.00013756167781858633 -2.9823317434336197e-5 -6.04904346782062e-5 7.275838988982051e-5 -0.0001092745590521305 -0.00017524394961391883 0.00010117086613003047 -0.00014609476083742205 -0.00013355050059354057; 9.700950046073751e-6 -0.00021020508995558647 -8.49276080988293e-5 9.705398713425329e-5 -8.89038971071088e-5 0.00019871909478400335 -0.00012423429563920518 -2.4570310982983302e-5 5.056392490236205e-5 2.128694848227723e-5 -6.129310770472132e-5 0.00013128815351240783 -0.0002384074147120072 7.253553792555384e-5 0.00010191436320031675 0.00023888408338837707 -9.777497018500676e-5 
-5.4337685127273076e-5 9.27665417986063e-5 -2.9588775418111538e-5 -2.3536089996802907e-5 -0.00018955852703935987 0.00012128203104669662 2.7676612604529923e-5 9.927087718218392e-6 0.00010896400215290873 -8.608676273421287e-5 4.0670667091155504e-5 6.316971428180867e-5 -5.6107550902976e-5 -0.0001584696261705764 3.0876643212138655e-7; 2.811689684365204e-5 0.00010102299803891067 -0.00015528014681278688 2.837007652176556e-5 -0.00012203435379290153 0.00018187738261674403 9.302937806519298e-7 4.15625518295035e-5 -1.2367406255557539e-5 -4.320388232683931e-5 -0.0001709254056399243 -3.838078540343501e-6 2.8504490743739638e-5 -4.5631232736399716e-5 -3.490355719206714e-5 -6.625108561222652e-5 7.05098873518868e-5 -2.6490445619334648e-5 9.038027293770885e-5 0.0001144972711850423 -6.0047382233665634e-5 0.00017742216825042693 0.00011757448466285002 -9.368562221736821e-5 -0.00016357187176565676 4.1728392731396365e-5 -8.868756320536293e-5 0.00012994792724982629 -5.1351241366701335e-5 -4.146609611241952e-6 -6.573917833831867e-5 -0.00015717528638872227; 7.376263430893607e-5 7.20204935697121e-5 0.00016850662552139354 -8.880290795962736e-5 0.00010385801771741667 9.31448348664101e-8 -9.269804821263705e-5 0.00011965494388102927 4.544118393649223e-5 2.901169329385518e-5 5.263732071456693e-6 -2.5330767932186292e-5 -2.9864742720686786e-5 -2.3066087925993013e-5 -0.00012244394554903007 0.0001262324532199855 -2.154031415733083e-5 -9.835000481137816e-5 -3.048122369539407e-5 -2.2351332410773857e-5 -3.472226115439705e-5 -0.00012839172075649474 -8.406887162157962e-5 -4.40374441331508e-5 -4.076088760660513e-5 -1.8735765372126238e-5 -9.98765779688196e-6 -5.342637985626055e-5 -2.200216102385304e-5 0.00010005855630705973 4.308615654560655e-5 -0.00022177466406167768; 4.2568029629012036e-5 5.7516917466830626e-5 5.168275537121564e-6 -0.00012657470911234213 7.546453302767359e-5 -3.0038065643918293e-5 -7.026115458883125e-5 0.00014579815938179494 6.361896844157486e-5 -7.878562287386324e-5 
0.0001766512028085453 1.5168960588756219e-5 6.212489420123686e-5 -8.662742443574241e-5 0.00012255103824701733 -0.00017744571637321138 8.372877668863491e-5 0.0002037972185904546 -3.710355708915889e-5 -9.869117316282897e-5 -0.00017382038318045204 -6.240305488189496e-5 -7.31851655031159e-5 0.00015506966478698127 -4.589814535743823e-6 8.256130563369165e-5 -0.00015783727253744883 2.7701372845160204e-5 -1.2133929386782432e-5 -0.00018456997027111848 -0.00018564331953837792 -3.143424005555414e-5], bias = [-7.077877976144833e-10, 2.2658209375401335e-10, 1.895150205598621e-9, 7.914322432827598e-10, -2.3155219693020318e-9, 2.936843847972799e-9, 1.6422780131962876e-9, -1.8919672780179916e-9, -4.040986269139523e-9, -8.626586398224321e-10, 7.969045828037919e-11, 1.1949920341182442e-9, -6.939553947035184e-10, -2.619094997934088e-9, -7.526063178898716e-10, 4.7590263591424e-10, 3.0245668743182823e-9, -3.1057633525056182e-9, 3.142212270954802e-9, -2.267671143843129e-9, 1.9154449008983554e-9, -7.671245999165557e-10, -1.845909366678581e-9, -5.91645696691385e-10, -2.544760613005634e-9, 1.6470855792651397e-9, -2.275799669656608e-9, -2.1903341918569904e-9, -5.207983738838187e-10, -7.921548115607132e-10, -1.2204345111350288e-9, -9.86302791740207e-10]), layer_4 = (weight = [-0.0006213388552397448 -0.0005852080835998407 -0.0007366882718395276 -0.0006638187867284822 -0.0006454110781190602 -0.0008320923085714789 -0.0006372297159005324 -0.000793835663527801 -0.0005006796816182057 -0.0006611223876407968 -0.0008636036052724226 -0.0007270284598566308 -0.0007333503348887806 -0.0006122408381893225 -0.000707179688117624 -0.0007355868907320488 -0.0005744348120905261 -0.0008147111875647709 -0.0006389813482977939 -0.0009236124373254438 -0.0005304279035905962 -0.0008519208027745623 -0.0006511301721029659 -0.0006110807776763492 -0.0007076413562766184 -0.0006657167280090714 -0.0008433324337438541 -0.0008168460077245483 -0.000650385606198951 -0.0005759624592717979 -0.0005797123508444127 
-0.0007233712809003027; 0.00044889610096692087 0.0003661254514817301 0.0002593839904216043 0.0002926990395906213 0.0002092977591556634 0.0003126382699822454 0.00011441932108943487 0.0002956754894445875 0.00019772717694069324 0.00026657856173489195 0.00018318603866429416 0.00037225327933478015 0.00025794510364980666 0.00013471509629861797 0.0003537598258135786 0.00026354055202984153 4.05773223403231e-6 0.00022097530775401614 0.00013393010290132143 8.314059367308108e-5 0.000140162437616706 0.0002967806225719931 6.457033239081356e-5 0.000262327366966068 0.00018456033273684817 0.00019541087138822585 0.00012783458296557605 0.0002216247887118604 0.00028720816860510476 0.00035043842363551047 0.00023918194190362615 0.00031205322541316504], bias = [-0.0006755511985572984, 0.00024224979972483265]))
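To give a feel for what BFGS with a backtracking line search does, here is a small textbook sketch in plain Julia, applied to a two-dimensional quadratic. It is not the implementation used by Optim.jl, and it does not model options such as `initial_stepnorm`. The function name `bfgs_backtracking`, the Armijo constant `c1`, the shrink factor, and the test problem are all assumptions for illustration.

```julia
using LinearAlgebra

# Textbook BFGS on the inverse Hessian approximation H, with an
# Armijo backtracking line search. `f` is the objective, `g` its gradient.
function bfgs_backtracking(f, g, x0; maxiters = 100, gtol = 1.0e-8, c1 = 1.0e-4, shrink = 0.5)
    x = copy(x0)
    n = length(x)
    H = Matrix{Float64}(I, n, n)
    gx = g(x)
    for _ in 1:maxiters
        norm(gx) < gtol && break
        d = -H * gx                      # quasi-Newton search direction
        fx, slope, α = f(x), dot(gx, d), 1.0
        # Shrink the step until the sufficient-decrease condition holds.
        while f(x + α * d) > fx + c1 * α * slope && α > 1.0e-12
            α *= shrink
        end
        s = α * d
        x_new = x + s
        g_new = g(x_new)
        y = g_new - gx
        sy = dot(s, y)
        if sy > 1.0e-12                  # skip the update if curvature is not positive
            ρ = 1 / sy
            V = I - ρ * s * y'
            H = V * H * V' + ρ * s * s'
        end
        x, gx = x_new, g_new
    end
    return x
end

# Test problem: f(x) = x'Ax / 2 - b'x, whose minimizer solves A x = b.
A = [3.0 1.0; 1.0 2.0]
b = [1.0, 1.0]
f_quad(x) = 0.5 * dot(x, A * x) - dot(b, x)
g_quad(x) = A * x - b

x_min = bfgs_backtracking(f_quad, g_quad, [0.0, 0.0])
println("BFGS result: ", x_min, ", direct solve: ", A \ b)
```

The closer the starting point is to a minimum, the sooner full steps are accepted by the line search, which fits the observation that a good Newtonian initial guess helps here.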
Visualizing the Results
Let us now plot the loss over time
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel = "Iteration", ylabel = "Loss")
lines!(ax, losses; linewidth = 4, alpha = 0.75)
scatter!(ax, 1:length(losses), losses; marker = :circle, markersize = 12, strokewidth = 2)
fig
end
Finally, let us visualize the results
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p = res.u, saveat = tsteps, dt, adaptive = false))
waveform_nn_trained = first(
compute_waveform(
dt_data, soln_nn, mass_ratio, ode_model_params
)
)
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel = "Time", ylabel = "Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth = 2, alpha = 0.75)
s1 = scatter!(
ax, tsteps, waveform; marker = :circle, alpha = 0.5, strokewidth = 2, markersize = 12
)
l2 = lines!(ax, tsteps, waveform_nn; linewidth = 2, alpha = 0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker = :circle, alpha = 0.5, strokewidth = 2, markersize = 12
)
l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth = 2, alpha = 0.75)
s3 = scatter!(
ax, tsteps, waveform_nn_trained; marker = :circle,
alpha = 0.5, strokewidth = 2, markersize = 12
)
axislegend(
ax, [[l1, s1], [l2, s2], [l3, s3]],
["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
position = :lb
)
fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.