Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning.
The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model, but from a scientific machine learning perspective these functions are not especially relevant.
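We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector r = r₁ − r₂, and split it into the positions of the two bodies, r₁ = (m₂/M) r and r₂ = −(m₁/M) r with M = m₁ + m₂: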
function one2two(path, m₁, m₂)
M = m₁ + m₂
r₁ = m₂ / M .* path
r₂ = -m₁ / M .* path
return r₁, r₂
end
one2two (generic function with 1 method)
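A quick check of one2two, using made-up masses that sum to one (as the masses built by compute_waveform do): the two positions differ by the original path, and the mass-weighted sum of the positions, i.e. the center of mass, stays at the origin. The path below is hypothetical.

```julia
# Copy of one2two from above so that this block runs on its own.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Hypothetical relative path: (x, y) positions (rows) at three time steps (columns).
path = [1.0 0.0 -1.0; 0.0 1.0 0.0]
m₁, m₂ = 0.25, 0.75

r₁, r₂ = one2two(path, m₁, m₂)

@show r₁ .- r₂ ≈ path                                        # relative vector is recovered
@show isapprox(m₁ .* r₁ .+ m₂ .* r₂, zero(path); atol=1e-12)  # center of mass at the origin
```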
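Next we define a function to perform the change of variables from (χ, ϕ) to Cartesian coordinates, x = r cos ϕ and y = r sin ϕ with r = p / (1 + e cos χ). The constants p and e come from model_params when soln has 2 rows, or from rows 3 and 4 of soln when it has 4: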
@views function soln2orbit(soln, model_params=nothing)
@assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
if size(soln, 1) == 2
χ = soln[1, :]
ϕ = soln[2, :]
@assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
p, M, e = model_params
else
χ = soln[1, :]
ϕ = soln[2, :]
p = soln[3, :]
e = soln[4, :]
end
r = p ./ (1 .+ e .* cos.(χ))
x = r .* cos.(ϕ)
y = r .* sin.(ϕ)
orbit = vcat(x', y')
return orbit
end
soln2orbit (generic function with 2 methods)
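To see the change of variables at work, the sketch below evaluates two hypothetical states with the parameters used later (p = 100, M = 1, e = 0.5). Because r = p / (1 + e cos χ), χ = π gives the largest radius p / (1 − e) and χ = 0 the smallest, p / (1 + e); the initial condition u0 = [π, 0] used below therefore starts the particle at apoapsis.

```julia
# Copy of soln2orbit from above (with the corrected assertion message) so that this
# block runs on its own.
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]
        @assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end
    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end

# Two hypothetical states as columns (χ, ϕ): apoapsis (χ = π) and periapsis (χ = 0).
soln = [π 0.0; 0.0 0.5π]
orbit = soln2orbit(soln, (100.0, 1.0, 0.5))  # (p, M, e), matching ode_model_params

# Distance from the origin for each column of the orbit.
r = vec(sqrt.(sum(abs2, orbit; dims=1)))
println(r)  # approximately [200.0, 66.67], i.e. p / (1 - e) and p / (1 + e)
```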
d_dt approximates the first derivative with second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
function d_dt(v::AbstractVector, dt)
a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
b = (v[3:end] .- v[1:(end - 2)]) / 2
c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
return [a; b; c] / dt
end
d_dt (generic function with 1 method)
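A small check of the stencils in d_dt: all three (left one-sided, central, right one-sided) are exact for quadratics, and for a smooth signal such as sin(t) the maximum error should shrink by roughly a factor of 4 when the step is halved, which is what "second order" means. The test signals and grid sizes are arbitrary choices for illustration.

```julia
# Copy of d_dt from above so that this block runs on its own.
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end

# 1. Exact (up to rounding) for quadratics: d/dt t² = 2t, including the endpoints.
dt = 0.1
t = collect(0.0:dt:2.0)
quad_ok = isapprox(d_dt(t .^ 2, dt), 2 .* t; atol=1e-10)

# 2. For sin(t), the maximum error should drop by about 4x when dt is halved.
function max_error(n)
    t = range(0.0, 2.0; length=n)
    dt = 2.0 / (n - 1)
    return maximum(abs, d_dt(sin.(t), dt) .- cos.(t))
end
ratio = max_error(21) / max_error(41)

println((quad_ok, ratio))  # ratio close to 4
```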
d2_dt2 approximates the second derivative in the same way: second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
function d2_dt2(v::AbstractVector, dt)
a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
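A similar check for d2_dt2: the interior stencil (1, −2, 1) and the endpoint stencil (2, −5, 4, −1) both reproduce the second derivative of a cubic exactly, so on t³ the result should match 6t up to rounding. The grid below is an arbitrary choice.

```julia
# Copy of d2_dt2 from above so that this block runs on its own.
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

dt = 0.1
t = collect(0.0:dt:2.0)

# (t³)'' = 6t; any remaining difference comes from floating-point rounding.
d2 = d2_dt2(t .^ 3, dt)
println(maximum(abs, d2 .- 6 .* t))
```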
Now we define functions that compute the trace-free moment tensor from the orbit and, from its second time derivative, the (2,2)-mode strain for one or two bodies.
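Written out, with x(t) and y(t) the rows of the orbit and m the mass argument (default 1), the functions below compute the following, evaluating the second time derivatives numerically with d2_dt2:

$$
I_{11} = m\Bigl(x^2 - \tfrac{1}{3}(x^2 + y^2)\Bigr), \qquad
I_{22} = m\Bigl(y^2 - \tfrac{1}{3}(x^2 + y^2)\Bigr), \qquad
I_{12} = m\,x\,y,
$$

$$
h_{ij} = 2\,\ddot I_{ij}, \qquad
h_+ = \sqrt{\tfrac{\pi}{5}}\,\bigl(h_{11} - h_{22}\bigr), \qquad
h_\times = -\sqrt{\tfrac{\pi}{5}}\;2\,h_{12}.
$$

In the two-body case, each h_{ij} is the sum of the contributions of both bodies, each computed with its own orbit and mass. Because d2_dt2 is linear, the trace terms cancel in h_{11} − h_{22}.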
function orbit2tensor(orbit, component, mass=1)
x = orbit[1, :]
y = orbit[2, :]
Ixx = x .^ 2
Iyy = y .^ 2
Ixy = x .* y
trace = Ixx .+ Iyy
if component[1] == 1 && component[2] == 1
tmp = Ixx .- trace ./ 3
elseif component[1] == 2 && component[2] == 2
tmp = Iyy .- trace ./ 3
else
tmp = Ixy
end
return mass .* tmp
end
function h_22_quadrupole_components(dt, orbit, component, mass=1)
mtensor = orbit2tensor(orbit, component, mass)
mtensor_ddot = d2_dt2(mtensor, dt)
return 2 * mtensor_ddot
end
function h_22_quadrupole(dt, orbit, mass=1)
h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
return h11, h12, h22
end
function h_22_strain_one_body(dt::T, orbit) where {T}
h11, h12, h22 = h_22_quadrupole(dt, orbit)
h₊ = h11 - h22
hₓ = T(2) * h12
scaling_const = √(T(π) / 5)
return scaling_const * h₊, -scaling_const * hₓ
end
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
h11 = h11_1 + h11_2
h12 = h12_1 + h12_2
h22 = h22_1 + h22_2
return h11, h12, h22
end
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
# compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2
@assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
h₊ = h11 - h22
hₓ = T(2) * h12
scaling_const = √(T(π) / 5)
return scaling_const * h₊, -scaling_const * hₓ
end
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
@assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
@assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
orbit = soln2orbit(soln, model_params)
if mass_ratio > 0
m₂ = inv(T(1) + mass_ratio)
m₁ = mass_ratio * m₂
orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
else
waveform = h_22_strain_one_body(dt, orbit)
end
return waveform
end
compute_waveform (generic function with 2 methods)
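As a sanity check of the strain formulas, the sketch below feeds a hypothetical circular orbit (radius R and angular frequency ω are made up) through a stripped-down copy of the one-body h₊ computation. Since h₁₁ − h₂₂ = 2 d²/dt² (x² − y²) and x² − y² = R² cos 2ωt, the exact result is h₊ = −8√(π/5) ω² R² cos 2ωt: the strain oscillates at twice the orbital frequency.

```julia
# Copy of d2_dt2 from above so that this block runs on its own.
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

# Hypothetical circular orbit: radius R, angular frequency ω, sampled every dt.
R, ω, dt = 10.0, 0.01, 1.0
t = collect(0.0:dt:1000.0)
x = R .* cos.(ω .* t)
y = R .* sin.(ω .* t)

# One-body plus polarization as in h_22_strain_one_body (mass 1): since d2_dt2 is
# linear, the trace terms cancel and h₊ = √(π/5) * 2 * d²/dt² (x² - y²).
h₊ = sqrt(π / 5) .* 2 .* d2_dt2(x .^ 2 .- y .^ 2, dt)

# Analytic strain: x² - y² = R² cos(2ωt), so h₊ = -8√(π/5) ω² R² cos(2ωt).
h₊_exact = -8 * sqrt(π / 5) * ω^2 * R^2 .* cos.(2ω .* t)

rel_err = maximum(abs, h₊ .- h₊_exact) / maximum(abs, h₊_exact)
println(rel_err)  # small; the one-sided endpoint stencils contribute the largest error
```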
Simulating the True Model
RelativisticOrbitModel
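Defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using u[1] = χ and u[2] = ϕ,
where p (the semi-latus rectum), M (the mass of the central black hole), and e (the eccentricity) are constants.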
function RelativisticOrbitModel(u, (p, M, e), t)
χ, ϕ = u
numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
denom = sqrt((p - 2)^2 - 4 * e^2)
χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
return [χ̇, ϕ̇]
end
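Written out, the right-hand side implemented by RelativisticOrbitModel is as follows; soln2orbit then maps χ and ϕ to the orbit through r = p / (1 + e cos χ).

$$
\dot\chi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p - 2)^2 - 4e^2}},
\qquad
\dot\phi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}}.
$$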
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions
datasize = 250
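tspan = (0.0f0, 6.0f4) # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # times at which the solution is saved
dt_data = tsteps[2] - tsteps[1] # sampling interval of the waveform data
dt = 100.0 # fixed step size for the RK4 integrator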
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
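With p = 100 and e = 0.5, the relativistic orbit is not closed. Dividing ϕ̇ by χ̇ in RelativisticOrbitModel, all common factors cancel and dϕ/dχ = √(p / (p − 6 − 2e cos χ)) > 1, so ϕ advances by more than 2π each time χ completes a cycle. The hypothetical helper below integrates this ratio numerically; for large p the advance should approach the familiar leading-order estimate 6πM/p (with M = 1 here).

```julia
# Hypothetical helper: dϕ/dχ = ϕ̇ / χ̇ from RelativisticOrbitModel, where every
# factor except √p / √(p - 6 - 2e cos χ) cancels.
dϕ_dχ(χ, p, e) = sqrt(p / (p - 6 - 2 * e * cos(χ)))

# Periapsis advance per radial period: ∫₀^{2π} dϕ/dχ dχ - 2π. For a smooth periodic
# integrand, the trapezoidal rule reduces to the mean of equally spaced samples.
function periapsis_advance(p, e; n=1000)
    χs = range(0.0, 2π; length=n + 1)[1:n]
    Δϕ = 2π * sum(χ -> dϕ_dχ(χ, p, e), χs) / n
    return Δϕ - 2π
end

println(periapsis_advance(100.0, 0.5))  # ≈ 0.198 rad per radial period
println(6π / 100.0)                     # leading-order estimate ≈ 0.188 rad
```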
Let's simulate the true model with the fixed-step RK4 solver from OrdinaryDiffEq.jl (loaded here via OrdinaryDiffEqLowOrderRK) and plot the results.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
axislegend(ax, [[l, s]], ["Waveform Data"])
fig
end
Defining a Neural Network Model
Next, we define the neural network model, which takes one input (the angle χ, i.e. first(u)) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but if you do use them, make sure to mark them as const.
We deviate from the standard neural network initialization and use WeightInitializers.jl: the weights are drawn from a truncated normal distribution with standard deviation 1e-4, and the biases are initialized to zero.
const nn = Chain(
Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
ps, st = Lux.setup(Random.default_rng(), nn)
Similar to most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64, so we convert the parameters with f64 before wrapping them in a ComponentArray.
const params = ComponentArray(f64(ps))
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
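The tutorial does not say why Float64 is needed here. One plausible reason, which is our assumption rather than the authors' stated motivation, is that a long fixed-step integration accumulates rounding error, and an optimizer such as BFGS is sensitive to noise in the loss. The plain-Julia sketch below (no Lux required) adds an illustrative step size `dt_demo` many times in each precision and compares the drift from the exact total; the values are made up and unrelated to the waveform data.

```julia
# Accumulate a fixed step size many times in a given floating-point type T.
# `dt_demo` and `nsteps_demo` are illustrative values, not taken from the tutorial.
function accumulated_time(::Type{T}, dt, nsteps) where {T<:AbstractFloat}
    t = zero(T)
    h = T(dt)          # the step size rounded to precision T
    for _ in 1:nsteps
        t += h         # every addition rounds to precision T
    end
    return t
end

dt_demo, nsteps_demo = 0.1, 100_000
exact_total = dt_demo * nsteps_demo  # 10000.0

err32 = abs(Float64(accumulated_time(Float32, dt_demo, nsteps_demo)) - exact_total)
err64 = abs(accumulated_time(Float64, dt_demo, nsteps_demo) - exact_total)

println("eps(Float32) = ", eps(Float32), ", eps(Float64) = ", eps(Float64))
println("accumulated error: Float32 = ", err32, ", Float64 = ", err64)
```

The Float32 total drifts by orders of magnitude more than the Float64 one, which is the kind of error a long simulation would carry into the loss and its gradients.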
Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics, using u[1] = χ and u[2] = ϕ,
where p, M, and e are constants.
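Reading the right-hand side directly off `ODE_model`, the system can be written as below. Here $\mathcal{N}_\theta(\chi)$ is our notation for the two outputs of `nn_model` evaluated at $\chi$ with parameters $\theta$ (`nn_params`). When both outputs are zero, the correction factors equal one and the rates reduce to the purely Newtonian ones. In the Keith et al. 2021 setup, $p$, $M$, and $e$ are, as we understand it, the semi-latus rectum, total mass, and eccentricity of the orbit.

```math
\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_\theta(\chi)_1\bigr),
\qquad
\dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_\theta(\chi)_2\bigr)
```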
function ODE_model(u, nn_params, t)
χ, ϕ = u
p, M, e = ode_model_params
# In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
# it; in general, however, we should use `st` to store the state of the neural network.
y = 1 .+ nn_model([first(u)], nn_params)
numer = (1 + e * cos(χ))^2
denom = M * (p^(3 / 2))
χ̇ = (numer / denom) * y[1]
ϕ̇ = (numer / denom) * y[2]
return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
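To see the role of the network in isolation, the sketch below re-implements the same right-hand side in plain Julia, with the network replaced by a hypothetical function `correction(χ)` that returns two numbers. The parameter values in `orbit_params` are illustrative, not the tutorial's. A zero correction gives equal Newtonian rates for χ and ϕ, which is roughly the regime the untrained network starts in, since its printed weights are of order 1e-4 and its biases are zero.

```julia
# Plain-Julia sketch of the right-hand side used in `ODE_model`, with the neural
# network replaced by a hypothetical `correction(χ)` returning two numbers.
function rhs_sketch(u, orbit_params, correction)
    χ, ϕ = u
    p, M, e = orbit_params
    y = 1 .+ correction(χ)                        # correction factors (1 + N(χ))
    rate = (1 + e * cos(χ))^2 / (M * p^(3 / 2))   # Newtonian rate
    return [rate * y[1], rate * y[2]]             # [χ̇, ϕ̇]
end

# Illustrative values for (p, M, e) and the state [χ, ϕ]; not the tutorial's.
orbit_params = (100.0, 1.0, 0.5)
u_demo = [0.3, 0.0]

newtonian = rhs_sketch(u_demo, orbit_params, χ -> (0.0, 0.0))    # zero correction
perturbed = rhs_sketch(u_demo, orbit_params, χ -> (0.01, -0.01)) # small correction

println("Newtonian rates: ", newtonian)
println("Perturbed rates: ", perturbed)
```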
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
axislegend(
ax,
[[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"];
position=:lb,
)
fig
end
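The simulation above uses `RK4()` with `adaptive=false`, so the solver takes fixed steps of size `dt`, while `saveat=tsteps` selects which time points are returned. As an illustration of what a fixed-step RK4 integrator does, here is a hand-written version in plain Julia, checked against the exact solution of a linear decay ODE. It is a sketch, not OrdinaryDiffEq's implementation; `rk4_fixed` is a hypothetical name, and it only returns the final state. It uses the same `f(u, p, t)` argument order as `ODE_model`.

```julia
# Classic fourth-order Runge-Kutta with a fixed step size `dt`.
# Hypothetical helper for illustration; returns only the state at tspan[2].
function rk4_fixed(f, u0, p, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    u = copy(u0)
    t = t0
    for _ in 1:nsteps
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)  # weighted average of slopes
        t += dt
    end
    return u
end

# Test problem with a known solution: du/dt = -λu, so u(t) = exp(-λt) u0.
decay(u, λ, t) = -λ .* u
u_end = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.01)
println("RK4: ", u_end[1], "  exact: ", exp(-1.0))
```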
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
function loss(θ)
pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
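For reference, the quantity returned by `mseloss` is the mean of the squared differences between the predicted and reference waveforms, assuming `MSELoss` is used with its default `mean` aggregation. A plain-Julia equivalent (the name `mse_sketch` is ours) is:

```julia
using Statistics: mean

# Mean squared error between a prediction ŷ and a target y.
mse_sketch(ŷ, y) = mean(abs2, ŷ .- y)

pred_demo = [1.0, 2.0, 3.0]
target_demo = [1.0, 2.0, 5.0]
println(mse_sketch(pred_demo, target_demo))  # (0 + 0 + 4) / 3
```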
Warm up the loss function
loss(params)
0.0006958535815598461
Now let us define a callback function to store the loss over time
const losses = Float64[]
function callback(state, l)
push!(losses, l)
@printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
return false  # returning true would stop the optimization early
end
callback (generic function with 1 method)
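The `return false` at the end of `callback` matters: Optimization.jl stops the solve early when the callback returns `true`. The mock loop below, in plain Julia with hypothetical names (`MockState`, `mock_optimize`, `stop_below`), mirrors that contract so its effect can be seen without running the full training. Here the callback records each loss and requests a stop once the loss falls below a threshold.

```julia
# Mock optimizer loop illustrating the callback contract: the callback gets the
# current state and loss, and returning `true` stops the run early.
struct MockState
    iter::Int
end

function mock_optimize(callback; x0=1.0, lr=0.1, maxiters=100)
    x = x0
    for iter in 1:maxiters
        x -= lr * 2x                              # gradient step on loss(x) = x^2
        callback(MockState(iter), x^2) && break   # stop when the callback returns true
    end
    return x
end

recorded = Float64[]
# Record every loss and request a stop once the loss drops below 1e-6
stop_below(state, l) = (push!(recorded, l); l < 1e-6)

mock_optimize(stop_below)
println("iterations run: ", length(recorded), ", final loss: ", last(recorded))
```

Because each step multiplies the loss by 0.64, the run stops after 31 of the allowed 100 iterations.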
Training the Neural Network
Training uses the BFGS optimizer. This appears to give good results because the Newtonian model already provides a very good initial guess.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
optprob,
BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
callback,
maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.00023612219956674422; -2.513400067978581e-5; -3.2675197871853696e-5; -2.0179813873248588e-5; 0.00015708889986840773; 0.0001543777616461189; -2.973885057140648e-5; -4.1926006815629534e-5; -5.918972601649801e-5; 0.00013178471999713624; 0.00017165626923077278; 0.00014526229642782183; 6.089618182154646e-5; 9.909419168240745e-5; 0.00010964333341677752; -9.503452747589492e-5; -4.332042954041067e-5; -8.982779399949123e-5; -2.2143924070382182e-5; 9.249347931467853e-5; 9.025256440502084e-5; 0.00021267830743441736; 2.5678540623600378e-5; 4.852687197849734e-5; -0.00017346930690117663; 7.325891601785241e-6; -8.738255564814915e-5; -7.42533229640715e-5; -7.747681229368391e-5; -9.708107245375845e-5; -1.084549967344934e-5; -3.4601809602407944e-5;;], bias = [3.1842010972702857e-16, -4.6426554631326726e-17, -1.5209210934369417e-17, -8.287734706682776e-18, 1.4870830680360255e-16, 5.977534917837337e-17, -4.6767662011185664e-17, 7.736453270486068e-19, 1.4556459732999928e-17, -5.383419407134108e-17, 1.1833894926554097e-16, 1.6797759000693582e-16, 4.9248505984823325e-17, 1.3114443334117987e-16, 1.341491380240319e-16, -5.524732122714292e-17, -6.876325088487355e-17, -9.308419682382967e-17, -2.6577800597144283e-17, 5.277128582726111e-17, 7.791221690291267e-17, 1.442739578259057e-16, 3.3336433814315513e-17, 4.567670455197623e-17, 9.09881328595218e-17, 1.081566871623095e-17, -7.803332859055593e-17, 3.80988571685256e-18, -7.65360929844345e-17, -1.3030959450957468e-16, -1.630904475848443e-18, -6.245152700058246e-17]), layer_3 = (weight = [-2.791240185450552e-5 3.5514681805473515e-5 8.910834149163052e-5 -2.8856185491502846e-6 5.811809533211379e-5 2.563500820711326e-5 8.715487783945952e-5 -4.124644183149804e-5 5.751317585407316e-5 -8.515361550793186e-5 -0.00011524744913745301 8.158479759172914e-5 0.00020505974826108206 -3.564236830982597e-5 8.627625502583106e-5 -1.9652376505541724e-6 -0.00017245204653419812 
0.0001384343137050325 0.00012703929096631605 0.0001820193370821226 -1.61683441530171e-5 -4.631106857975483e-5 5.7709830436472516e-5 -2.931567578172622e-5 3.100130707193465e-5 0.00015202038921587115 9.670541621915093e-5 -9.766767469157596e-6 1.1996181094106069e-5 6.553446437635287e-5 4.115243209697246e-5 -0.0001287919352745283; -8.611113417819212e-5 -0.00019770272893760463 3.7177359163304e-5 4.659534414718733e-5 2.1353706588411195e-5 -0.0001284266630226971 -7.989347187498679e-5 -2.919915286786023e-6 2.5745365464953637e-5 2.5064754199838354e-5 -2.9517821621278145e-6 -9.476456153600119e-5 0.00013712760000094868 -2.480495396406338e-5 -5.7489186756296854e-5 0.0001477597094799997 -7.115565796079138e-5 2.4079918775953604e-5 -0.00020688824707571534 -4.654723426023067e-5 -7.415847477199107e-5 -0.00015130551883877112 9.93476120948408e-5 5.461123027094648e-5 1.7355804348043566e-5 -0.00014336311261098982 7.519144736767184e-5 0.00012091817797232643 7.917224746966991e-5 -0.00012926769096333044 0.0001045057640022211 -9.476679525500237e-5; 4.4162818788409186e-5 -0.0001228525536378532 -5.603512838824302e-5 0.00015915438122938723 -3.132795175296265e-5 -0.00010079564092556044 -2.5545151608452463e-5 0.00014469881423174772 -0.00012931081003152357 -0.00014411182494476627 7.778119904579701e-5 -0.00012879219432469555 -8.703259957364349e-5 2.0010148339868553e-5 2.447483600299954e-5 -0.00017886839052770677 -3.1491142424729483e-6 -3.307420340824352e-5 5.509048494168894e-5 1.313095869882796e-5 0.00013203505366000867 -7.89491561834279e-5 3.6646638157622565e-5 0.00010402091427131936 -0.00016975008768603083 -6.728153241674091e-5 1.1973981393755379e-5 0.00011625075855882362 -7.304713947341308e-5 5.81675019686445e-5 0.0002652136725129872 -3.3017276315871814e-5; 1.3495772313172322e-5 -2.488341162080874e-5 4.9808746174843764e-5 3.4726031882969292e-6 -4.5796496502435255e-5 0.00016122823179482118 -1.7527607647767403e-5 -4.794126327016452e-5 -1.3021585269432239e-5 -9.690194966954183e-5 
5.834065776449876e-5 -0.00011957397183853417 5.857676191392283e-6 6.19307510003517e-6 -5.507236384421291e-5 0.00019769784883845123 0.00011218545440368288 0.00012867748917795303 8.880527055606257e-5 -4.4878146960236826e-5 4.819446220079017e-5 2.0956749344038013e-5 -3.851063267194665e-5 -8.964496047685016e-5 -7.69854854441599e-5 -2.1472904714587728e-5 -0.00010256604256557691 -3.8432023225885734e-5 9.593418834281788e-5 4.647846033929357e-5 3.6398701544553345e-5 0.00022579305694688082; -0.00017359728231644407 -3.285593853185073e-5 0.00014090195467609805 0.00011346520617984109 -0.00011462544940264214 0.00010537357733732742 9.517703578488791e-5 0.00016430012469094164 9.302907112885497e-6 -2.2758611937440276e-5 -6.133760912894364e-5 -1.853819818030015e-5 0.0001891080697011946 -6.932676530832422e-6 3.5218161013054466e-5 0.00010927908582993013 0.000145076084247864 6.332792831266426e-5 -0.0001573113172137377 -8.011922769334025e-5 -0.00026497769291132555 0.0001317612564369055 -0.00016616272337820502 -8.260452019136462e-5 0.00028707240702965226 -0.00018144035717092592 -0.00018469011998355274 8.40798034393978e-5 -4.531262534988985e-5 -8.829952135534783e-5 -0.00018447901534933734 -8.737778849287059e-5; 0.00012059843262739414 0.0001806042664296844 -3.36775092679158e-5 -7.49031578694551e-5 -5.745214460567939e-5 -0.000166851391472494 6.824014687062357e-5 8.242228338112222e-5 6.806908910711359e-5 -0.00011807302965699915 7.761839784997051e-5 -1.6270366390573952e-5 -6.200684314045065e-5 -9.073740241966277e-5 -9.764994050293341e-5 -0.00016809671527390416 -8.681775674140917e-5 8.397943289800515e-5 5.1303656587508796e-5 0.0001392550575159393 6.015698360478473e-5 -3.4482787074168842e-6 5.212581041080436e-6 -0.00012107732348759573 2.6299229798847692e-5 -4.1462427393168766e-5 6.845237200231268e-5 -0.0002647253468554558 9.609177047004002e-5 0.00019503048221430685 0.00014938767250477464 0.00010158791243647333; 7.514338061801921e-5 7.870816329156186e-5 2.9856332494058687e-5 
0.00010356376582418727 -1.9471338655475306e-5 -5.008269659816376e-5 7.341050944884423e-5 6.290895247726737e-5 0.00011339965998555413 -2.5323438658280958e-5 9.214652654860264e-6 -4.962412209356978e-5 -3.924984648655178e-6 4.170763053476119e-5 0.00010166570858923762 -2.0568500312987393e-5 5.250124808295654e-5 -8.31124840264509e-6 0.00011944900028535316 -4.061356163223852e-5 -5.106806498592327e-5 -6.084357597508619e-5 9.650816058634917e-5 0.000101577400291683 9.941932943541836e-5 -6.839799176308136e-5 -0.0002195523295631961 7.658028765941665e-5 4.745673208159047e-5 -7.857176684414693e-5 0.00015463262092449884 8.866595398980113e-6; -6.8624804518562e-5 2.979790358766221e-5 4.36300992787281e-5 -0.00022972428498361955 7.898889951137264e-5 -1.572181328994322e-5 -9.402970547127298e-5 -0.00013569631348401995 -9.275677668667146e-5 1.023734100998244e-5 -9.72892908266799e-5 0.0001715282464628981 9.032902514989742e-5 5.793025182172945e-5 -9.352570052230051e-6 0.00011568684338088433 -4.3017617392267435e-5 -9.176335444878322e-6 9.959437509495071e-5 5.590537100759617e-5 4.7252125548751104e-5 -3.199420691272692e-5 -0.00012075525093861618 -1.964928244162103e-5 0.00010845381391662572 -1.7858163593716428e-5 -2.5873815067840763e-5 8.566192762189808e-5 -2.6091679066682346e-5 3.3311546103353186e-6 3.94794201183529e-5 -4.555500574680734e-5; -0.00016778168778142647 5.4447405415373116e-5 -3.762276596131714e-5 2.0986862506503415e-5 2.9917607531786226e-5 2.36681602077058e-5 8.941934661964906e-5 -7.916894959959179e-5 6.525145050942975e-6 -2.2047721755030084e-6 2.7425322838496878e-5 6.791593055498889e-5 5.8070439405525425e-5 4.084977204587367e-5 -0.00017214051026127238 -0.00021774599430824577 9.907660874563293e-5 4.097664655677093e-5 7.877489711648803e-5 3.860449519964918e-5 7.506161942404745e-5 0.00016945151160729322 -6.246485704065282e-5 4.051909068624515e-5 -7.947433590345066e-6 -0.00029482879894527584 3.949028119353626e-5 -0.00012470036439805255 1.9230477261259634e-5 6.298030653527993e-5 
0.0001258401056078916 -7.785309994102434e-5; 3.694937485563895e-5 7.337149080995435e-5 7.124264747545107e-5 -1.105000838578721e-5 -8.794813830462279e-5 -7.178125152421213e-5 -4.9919519435162645e-6 -0.00017919496231770133 -5.608659622084267e-5 -0.00011781774147237941 0.00011107079562208813 -1.1324702161060642e-5 2.0755119560641217e-5 -0.0001010733619361337 -0.00015456248430126854 8.263791032078051e-5 -4.2629177830526436e-5 3.750043405744514e-5 -8.616047768732924e-6 9.517862717330742e-5 -5.237818250606497e-5 0.0001150629971496751 -1.998519902117117e-5 -2.7581291494420367e-5 -0.00014535365399495476 9.774915023881885e-5 2.4775009744924186e-5 -0.0001280029107030651 3.54399556206837e-5 3.0397205858033572e-5 5.2606555339302326e-5 7.067427116276125e-6; 0.0001612396341678138 -0.000122379958496962 9.070945509469866e-5 -2.64181957155907e-5 0.00014192937362700098 -8.285128876573093e-5 -6.190465080558527e-5 -9.94288109191415e-6 -6.675087513408771e-5 1.6410231752024757e-5 7.378460270579412e-5 -0.00010714012330404232 -0.00011537588886690468 -0.00018211842268479065 4.7779431611933264e-5 0.0001759728080522546 -3.667862563284299e-5 8.23114498948401e-6 0.00013144185197765722 -9.005174376775546e-5 -3.417524510419039e-5 -1.0601331609125673e-5 -3.241734827875372e-5 1.701982967083169e-5 -1.2919519917110567e-5 -0.00013584615533014053 -0.00016282131707535182 2.3713399628693145e-5 2.9067063793124647e-5 1.4334874342527453e-7 0.00014221601725316382 3.83234261115341e-5; -3.6373636935906075e-5 0.00019439444735731 -0.00014541148698803951 5.517809699490665e-5 -1.925751053526327e-5 0.00012144596552794515 -1.7394303318313205e-5 -0.00010971750885396231 0.00011628077049358832 8.658886076269207e-5 -0.0002987813538816584 0.00010044947270622117 7.015004252437235e-6 -0.00014156425163996147 -2.5127804070093265e-5 -0.00010307233310918086 5.970077767820137e-5 -0.00012412980573166883 8.732467108025215e-5 5.7502605381572996e-5 7.652477801862388e-5 -0.00010410154641752753 -0.0001393827449242491 
-4.299417265023984e-5 0.0002797537265400577 -2.4567235191669427e-5 -1.720372779850184e-5 -7.960632206554384e-5 -7.103258099554885e-5 -0.00014223333415024666 1.957065182473501e-5 0.00014473577470603416; 6.48188108393864e-5 0.0001023192710058714 -8.518751909239714e-5 2.416797016532594e-5 5.96862102294737e-6 -2.6386327393521305e-5 9.26425820730577e-5 1.087827896892059e-5 5.218933457992469e-5 -0.00010075055232253449 -8.302217955450139e-5 -9.112512978707038e-5 9.214593975822869e-5 -0.00010145591275748377 9.482203696872524e-5 -3.2497962623420615e-5 -9.596702492079264e-5 4.347355226538566e-5 -0.00026473785308870056 -8.278580738516037e-6 0.0001481525871136175 2.6284919128118932e-5 -6.796176220039402e-5 7.306217250199489e-5 0.00017351597872813622 3.883936753165411e-5 0.00010524828777899457 -0.0001416116827235608 7.729902082839475e-5 7.260901130986819e-5 -0.00011876985425577063 -6.744403416039919e-5; -4.721010181073371e-5 0.00013004238290760276 -6.490225486176477e-5 -8.382733348243329e-5 4.398080283009109e-5 0.00010205242152606728 0.00011588098287630437 -0.00015620038165197164 1.8599188579726863e-5 5.2626441275250356e-5 0.000107843356191196 2.6564243044602952e-5 -0.00011949887983918573 7.850126722642782e-5 5.0327667950667515e-5 -1.6588784901061082e-5 7.245149036917448e-5 -4.57921304490734e-5 8.305977836299297e-5 -4.463877475190349e-5 -0.00012977445284322278 6.955939180498649e-5 0.00012434047685984132 -5.9460413367007056e-5 0.00016210041222676307 -0.00010172593455480803 3.467544967007449e-5 3.683303303715394e-5 -0.00016270837649189312 -0.0001101146916784197 3.195700274828432e-5 -0.0002170752631076464; -5.002281873506587e-5 6.613142354276326e-5 -6.915101662648535e-5 -8.28890796159179e-5 9.445575155646546e-5 7.554669459446318e-5 8.527428612673957e-5 -0.00020426661908764068 1.283966204617678e-6 -3.7607513925965905e-5 -6.88922253660565e-5 -2.040201875692789e-5 -7.979242477936294e-5 0.00012407856151789966 5.941129452062438e-5 -1.082752974375564e-5 -2.2537577800314394e-5 
0.00014073820525177316 1.7228948442286915e-5 0.00011617877784979103 0.00011073583040947589 -0.0001226343618812636 3.640785612028386e-5 -8.817833700354801e-5 6.712985984137391e-6 0.00014089027276590745 -6.92897836901013e-5 -2.8233512105863024e-5 -7.731789858872353e-6 -0.00011158573293264427 -0.00012672627313200106 -9.382805260734322e-5; 0.0003866865308679361 9.487172155219846e-5 1.342034647111021e-5 -5.00524938672519e-5 3.6725517014675556e-5 0.00020219382803927398 7.776372603418047e-5 9.697373115501878e-5 -0.0001647524031852632 -0.00022363281369167844 2.1064784044675914e-5 7.06453292820954e-5 -8.80726502202715e-5 0.00011986125263661123 6.981214482379654e-5 0.00010600466429764526 5.530301586419835e-5 0.00019347846773030533 -7.405740240056328e-6 5.714601228988006e-5 -1.8949815972742598e-5 -3.547542565349429e-5 0.00011348194038307075 0.00010363695127468028 4.485441331208702e-5 -3.767331924090909e-6 4.180836639875921e-5 7.77996692647861e-5 8.842841725141482e-5 4.3164546690378834e-5 -3.880446730020578e-5 2.7851113367221865e-5; 5.57467761934349e-5 9.655012457709003e-5 5.8983023914817434e-5 -2.8591311424195786e-5 -3.0326992932058207e-5 -2.5292374052040214e-5 -0.00011451887200043372 -2.1801695187849544e-5 0.0001507924104186992 -4.1552454935549937e-5 4.986845363939151e-5 -9.040827659580556e-5 -5.581728634230738e-5 7.652085390484705e-5 -0.00012444792752309974 -5.404357158501787e-5 -6.203656027016104e-5 0.00013067735620618686 -7.273271665036072e-7 2.575595375285813e-5 8.289926393966069e-5 0.0002485203809880773 -0.00016194975556961086 0.0001591497499573335 -0.00010881859576718938 -9.478878870528924e-5 5.22115775636201e-5 2.7054628714993093e-6 0.00023685137478944815 7.28223464063631e-5 -5.072099826453991e-5 -9.497538063830541e-5; -8.185978681862587e-6 3.988247276160208e-6 -3.3521282371490115e-5 4.5637971429263585e-5 4.092933909721619e-5 -7.499299658359057e-5 7.367310923162588e-5 0.00020765874092338322 -2.6921486774351118e-5 0.00023510736378246156 3.966594545506672e-5 
-0.0001410169641034579 3.5144712687345494e-5 -3.6536762986924235e-5 8.722225198160258e-5 1.894159345700363e-5 -7.921463971266628e-6 8.561644813615351e-5 -4.742884606950232e-5 -9.826622385427848e-5 7.513241894646999e-5 -6.447884492939087e-5 0.00013883195083925248 -4.744750890078328e-5 -0.00024664271017666193 -6.516406070127866e-6 7.998012965419332e-6 -0.00015884051367303753 -6.945509062044739e-5 0.0001930919827414811 6.159894860919329e-5 -2.356285562692945e-5; 5.3375193493142166e-5 2.0030109117719204e-5 0.00015552322463968627 -0.00010947247766984716 -6.138837317800403e-5 -1.5761808862832345e-5 3.5052127605517044e-5 -3.5982824665049236e-5 4.6769566216018305e-5 0.0001832671931702307 -1.1139825000529656e-5 0.00012305199368760033 -1.8216449570398693e-5 0.0001677493945685873 -1.3187760664151194e-6 6.304155420416157e-5 0.00012605303516516252 -8.33202703865504e-5 0.00022602313572911453 -0.0001295314779049049 -4.386935517636863e-5 1.3079000273604262e-5 1.7544600159957843e-5 2.7808260602389747e-5 3.054088472756783e-5 -0.00016386830804498924 -2.861004952087759e-5 -2.9266668315767802e-5 3.930576520410743e-5 0.00014795018967610871 1.9546625372235956e-5 -6.246739769216782e-5; 0.00014737550596220493 -5.93201873696105e-7 3.601220528685474e-5 -8.551725254187443e-5 -9.62480676618906e-5 5.5955903421617337e-5 -0.00019800184740711042 -8.291264334953747e-6 1.1981768197761905e-5 -4.267471153546309e-5 5.245206053112774e-5 -0.0001422010150647329 -4.718161066670642e-5 -1.1685192971742151e-5 -7.386744759536156e-5 3.688342117562771e-5 -3.699034054556792e-5 0.00014895722638790274 -2.6527058749022705e-5 4.760789529860783e-5 8.665085847929842e-5 8.917830059981821e-5 -8.209021829793285e-5 -5.7008933750007965e-5 -6.5497547850225075e-6 -5.1910604766133516e-5 -0.0001649991368551486 -5.668584121416808e-5 2.422440634723785e-5 -0.00012473344899739216 -3.645610335774833e-5 -0.00024995900962062185; -3.914703074667432e-5 0.00011490544466244555 7.834255711816747e-5 3.439352485759684e-5 
-4.0021117000715975e-5 -6.317517499962462e-5 -0.00015996733372855435 -7.358490211015076e-5 0.0001787235164566759 -0.00016248375277325112 -8.59901547198322e-6 -2.7957357359692406e-5 -7.054537081682576e-5 -4.403816592354815e-5 -0.00018190229411527086 8.48344939196341e-5 0.00010539020922909051 3.304817119090341e-5 9.73521750903552e-5 0.00017585603245709545 6.33238237252318e-5 2.816562164400938e-5 -8.908201833828024e-6 5.3851036228598324e-5 -7.339785179180703e-5 -6.264743524195126e-5 -8.240959823421677e-5 0.00022895949268889154 -8.410075633843894e-5 -2.2149931736630586e-5 -6.119279669189324e-5 3.924232487509523e-5; 6.505752741969247e-5 -7.585656215064448e-5 0.0001022090763146971 -0.00016052186777210103 -4.2521613184110504e-5 -8.340661964004864e-5 -0.0002748615189627764 8.681970564089596e-5 -7.827876482019165e-5 -0.00010443258736607274 7.8884495335189e-5 2.120031956317118e-5 -9.137748213810794e-5 3.0499826491234322e-5 -4.962601463169925e-5 0.00013594939434575446 -9.617781700193991e-5 -0.0001678127702032171 -9.777377647673943e-5 -9.767134554543976e-5 2.7284659038069524e-5 3.627385001190985e-5 5.882746688462575e-5 -8.790597464889954e-6 8.904914636539849e-5 -0.00015503789120679994 -9.067236180975628e-5 1.4965931652348789e-5 2.561955613806172e-5 0.00019967394825677024 0.00011634364261431352 -0.00018323804772785106; -0.00011968793573162122 -6.488273087730195e-5 -7.53419708793937e-5 -8.566031321110596e-6 2.914964803363438e-5 8.673767566375168e-5 0.00013007367211552787 0.000111805830746149 2.53158131804841e-5 0.00010047498195358893 -9.797658007431761e-5 2.0181360886051015e-5 -2.4686968610196974e-5 -3.9072828295911075e-5 0.00020842664067327122 6.354218472420166e-5 -0.00010922058303681634 -0.0002943754125603393 4.526422249777673e-5 1.0577454266740425e-5 -0.00017912617935639927 1.8866242461235834e-5 -8.385142158790985e-5 -0.0001487796444429964 -1.4719674292280809e-5 -1.1181286974037636e-5 0.00010062611086918828 5.291624889518816e-5 -0.00012782231082515087 6.568269141876085e-5 
1.898175374530668e-5 -9.885217608958018e-5; -0.00012815003494999477 5.180970990528794e-5 -7.875669464504825e-5 5.192169780690659e-5 -0.00023516544527176682 -0.00011426207995776025 -7.992201929249353e-5 7.286253978750925e-6 4.151874666209773e-6 -0.00014487423097983093 2.0215830293248528e-5 -3.696388514747845e-5 -1.4871918690399708e-5 1.5064788815441628e-5 -0.00013293671284139374 -7.437813976815262e-5 7.641694967392884e-6 -9.045249978437417e-6 5.488604299222557e-5 -2.9376125464689203e-5 9.389900739967511e-5 8.702671263797783e-5 1.0730899223349752e-5 4.174175635177193e-5 0.0001877153794999442 -4.263718212990594e-5 9.40447448306871e-5 9.317942246758938e-5 -5.466052746435989e-5 -2.5988656059264212e-5 3.8824955911482314e-5 0.00013982597235787182; -3.879345111264155e-5 0.00011352440689744776 -9.856607592144654e-5 0.00012251417994786414 0.0001038837921624745 5.2344550040385376e-5 4.271922209238039e-5 -1.1825676597964803e-7 5.941387409631382e-5 2.1577026655991796e-6 1.638614905469118e-5 1.4653169655600998e-5 -0.00019668825059842107 6.993942717609869e-5 4.848652808515197e-5 -0.00011754997338347163 0.00020821797572363458 0.00016990726491723223 -5.284944070797992e-5 -6.72500898862029e-6 -8.098557244000943e-5 0.00020619847548013776 -0.00011042561762213361 -7.232478908907354e-5 -3.856743804126077e-5 0.0001947488539554961 -8.744227905470843e-5 -0.00012160291460521389 -7.586390218410534e-5 7.977794795816722e-5 -6.617451235690176e-6 -1.9757117643059733e-5; 2.7998830660777245e-5 -9.250235063921492e-5 -4.04421535899281e-5 5.482672884279067e-5 -2.564974977248522e-5 0.00015186083862827799 6.990619101595337e-5 5.024156587348645e-5 1.2070999105685523e-5 7.4284083671922225e-6 -0.00029665963131663824 -8.439604619015685e-5 -6.062613658689292e-5 -4.667971927521221e-5 -0.00010320715421076956 -2.2741445831722117e-5 -0.00015130056465243578 0.000127430603129671 3.577305130927039e-5 2.2499791044158653e-5 -0.00023091866216820367 -6.018882970634924e-5 -5.2761619485344876e-5 2.9376433190530547e-5 
5.6748713064385885e-5 -0.00013954135113129304 -4.025616192340629e-5 2.102145833812578e-5 -0.0001046782291490865 -4.1310368146165194e-5 -7.868854493101965e-5 -0.00013540942211368804; -0.00016143713205865932 -5.371223583393838e-5 4.791256301062973e-5 8.175567577085367e-5 4.39876859188362e-5 8.299642299491068e-5 -9.291783849517941e-5 0.00010406254673101751 6.066669461391552e-5 1.97218943382948e-5 -7.851087489442285e-5 9.77749770270037e-5 0.00013591540859737104 3.252024193600914e-5 5.838657684464944e-5 7.394381477454898e-5 -5.1073635271007305e-5 -1.0368146179351821e-6 -0.00011742541742320269 0.00012590297787640502 8.253117643719823e-5 -3.4601467662443806e-6 -7.141861338210402e-5 -1.8771844254509685e-6 -6.682795706482034e-5 2.0384776660089008e-6 0.00021872552692614816 -9.83373864558106e-5 -5.892404429074871e-5 5.8384910650357655e-5 1.8847991802193444e-5 0.00010543512703110315; -7.896293823934797e-5 -0.00010326156621661611 -9.484777979881137e-5 6.920309100902994e-5 -4.6228452226866115e-5 -0.00012585509535486083 -3.1160249598115796e-5 2.9999742075660288e-5 -3.8760583954654334e-5 4.6331274681935706e-5 6.067017533721283e-5 0.00012535483551116413 0.00013758847057258183 0.00010056080213186833 0.00013262229480685272 -9.54504691958658e-5 -2.3418090456771622e-5 0.00015210773130317694 5.13699990210345e-5 5.2383921904466276e-5 -2.8115330458176225e-5 -0.0001218358781966554 -7.652320051195674e-5 -9.962665786363733e-5 -3.9608334880109114e-5 3.646916160730061e-5 -0.00011119137776179702 1.2355319306580033e-5 0.00013495832831037673 2.4840213193302343e-5 -0.00011386197434833856 -6.939462469520613e-5; -4.204921557351515e-6 0.00010505486202771477 -9.924658447391702e-5 0.00018682371806232913 0.00010813513868365165 -0.0001078445399308781 6.756003345242376e-5 -0.00011851775857042016 -0.00015022371157658918 1.1452764150685277e-5 3.079887606312631e-5 -1.8272297571593425e-5 0.00014258088990925158 -2.0720519065558498e-5 -1.3700428686332683e-5 0.0002098008720657755 7.151026725648461e-5 
-6.603996520983278e-5 4.0601075264818005e-5 -5.621935394031148e-7 2.8764396441935104e-5 0.00022890950126913537 0.00015148279121956484 8.213078613505204e-5 2.2629623666797208e-5 0.00010511448850036366 -2.7465446370236447e-5 9.118626108272071e-5 -4.348770894114844e-5 -5.241486311329362e-6 -9.264808399723152e-5 3.964296261426629e-5; 9.458417872172216e-5 -2.8762389494525852e-5 -0.00010431449817756693 0.00019674516471072497 5.813527042480609e-7 3.716617385012734e-5 -4.13232443787395e-5 6.301483909569214e-5 0.00011857576128763698 -6.337144776975423e-5 1.9390427322070892e-5 -2.9219654326736598e-5 -4.496128866944832e-5 -8.2156944419556e-5 -6.817664297327355e-5 1.6109105043279828e-5 -0.00012618005320460063 -1.8767308294490568e-5 3.108624178653052e-5 2.3776309088271698e-5 -7.428321595566495e-5 2.107460414497023e-5 -9.421046382644121e-7 1.3555737776321174e-5 -6.763786558789185e-5 -1.1626023198781473e-5 -9.202199334308679e-5 -1.7702713823246776e-5 -7.15662060355048e-5 -6.521052609227685e-5 -2.2117251948524173e-5 -1.6930760186235716e-5; 0.00021064155113781067 2.1372456802777182e-5 -7.626344296301072e-6 2.8683750105113883e-5 -8.936770357759573e-5 9.563382963092046e-6 -0.00011749259031712991 3.84586662847085e-5 -6.849431268274759e-5 0.00018097568205921048 -0.00011163852944373974 0.00011627141317948783 -0.00011370798641219708 -3.657670570756263e-5 1.8051742489473493e-5 -4.597585687516213e-5 4.617623814628479e-5 -9.313603844130371e-5 0.00011604857242563755 -1.0492817847267306e-5 0.00019469289943007791 -8.863262813174936e-5 0.00011178951787659916 -7.781915085725463e-5 -6.536771728061027e-5 9.678117455113694e-5 0.00014316461126217655 -0.00013461701494744017 -5.1646054299641064e-5 2.506387206157714e-5 -0.00015983977685319245 -0.00011226750687565636; 0.00014799427837936004 -0.00010122449937404088 6.491308697371781e-5 -5.9421390758472894e-5 -0.00017406269951385914 2.3699698360753403e-5 -1.2456982474361588e-5 3.9201864567473785e-5 6.752517030912436e-5 8.711674965919607e-6 
-0.00012378696744924045 1.924318159802425e-5 0.00011323954633775148 8.588235316204688e-5 -3.836230651137615e-5 1.9700720917495916e-6 9.453880731820149e-5 -0.00016475842319556622 2.190796287544588e-5 0.00017243861468381127 -0.00012853879350365252 -7.47307586403375e-5 -1.2494732870948031e-5 5.930064546589381e-6 -0.00010307326199615029 -3.1084475338104506e-5 9.787925766653278e-5 9.589388894022385e-5 2.2378162541314285e-5 5.412763969102861e-6 0.0001712662250815245 -5.7224433546764407e-5], bias = [3.486640601890321e-9, -2.0966671795089382e-9, -7.450824291622188e-10, 1.7276617515888822e-9, -1.0245069130395476e-9, 6.211534062585628e-10, 2.7764508240588894e-9, -1.6964078645708925e-10, -1.2791457569612898e-10, -1.2521079002246152e-9, -6.942847486926146e-10, -8.483362292298823e-10, 4.7490647682188576e-11, -6.538302096702128e-11, -8.320542769748907e-10, 5.588839276625628e-9, 1.2011112013401353e-9, 9.31554117705888e-10, 2.57012339254859e-9, -3.2064647240066264e-9, 1.8235471602165598e-10, -2.7322382693912505e-9, -1.547088190578054e-9, -7.326229339695914e-10, 1.3582576848877263e-9, -4.3747826843988894e-9, 2.2229385090307668e-9, -4.845401516844755e-10, 3.5528081539713753e-9, -1.5458607356643856e-9, -9.934520334612278e-11, 6.602226110721006e-10]), layer_4 = (weight = [-0.0005229489322768489 -0.000744427696793997 -0.0005952474542651974 -0.0006412256983845283 -0.0007357040404820251 -0.0007430372837737655 -0.0006770236307973326 -0.0006428254487190189 -0.000780170892358476 -0.0005074981164557682 -0.000829457358571604 -0.0008048181658884271 -0.0008688803865850869 -0.000821679794298581 -0.0005961692716866898 -0.000619942633482522 -0.0008710839587948043 -0.0007494633677304978 -0.0007137871882734027 -0.0006503293929002428 -0.0007396181820935056 -0.0008678854174934661 -0.0005506040414728361 -0.0007684527891656483 -0.0008408482995605021 -0.0004736473241437501 -0.0006280519672257521 -0.0006489714959138545 -0.0006565813513578144 -0.000592781741270652 -0.0006965067222987789 
-0.0006355005392587444; 0.00019687622850310702 0.0002703319607883781 0.000334550258171557 0.0002775228901270265 0.00019803541920192255 0.00029902851046230997 0.00018959137134392495 0.00012018318348448235 0.00018950964277504102 0.00021041262008144723 0.0001434788749212312 0.00024186323116980973 0.00012216074532314354 0.00016816878625956397 3.759665006674623e-5 0.0003789129632155472 0.0003164204622582981 0.0004005054066284545 0.00024311213219034523 0.000202572519359105 0.0002906114892964887 0.00027992841357181143 0.00022702264292163448 0.0002446406586751507 0.0002581448587442617 0.00020476877816618974 0.00016831251522704362 0.0003170637844951167 0.0001824942385120974 0.00019926310842116905 0.00020366787418546353 0.00034951090124138266], bias = [-0.0006760628823594855, 0.0002290847696372865]))
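To make the optimizer choice concrete, the sketch below implements textbook BFGS, which builds an approximation to the inverse Hessian from successive gradient differences, together with an Armijo backtracking line search. It is written from scratch in plain Julia to illustrate the idea behind `BFGS(; linesearch=LineSearches.BackTracking())`; it is not Optim.jl's implementation, and `bfgs_sketch` and its keyword arguments are hypothetical. A small convex quadratic serves as the test problem because its minimizer is known exactly.

```julia
using LinearAlgebra

# Textbook BFGS (inverse-Hessian form) with Armijo backtracking.
# Hypothetical helper for illustration only; not Optim.jl's implementation.
function bfgs_sketch(f, gradf, x0; maxiters=200, gtol=1e-6, c1=1e-4, shrink=0.5)
    x = copy(x0)
    H = Matrix{Float64}(I, length(x), length(x))  # initial inverse-Hessian guess
    g = gradf(x)
    for _ in 1:maxiters
        norm(g) < gtol && break
        d = -H * g                     # quasi-Newton search direction
        α = 1.0
        fx = f(x)
        # Backtrack until the Armijo sufficient-decrease condition holds
        while f(x + α * d) > fx + c1 * α * dot(g, d) && α > 1e-12
            α *= shrink
        end
        s = α * d                      # step taken
        x_new = x + s
        g_new = gradf(x_new)
        y = g_new - g                  # change in gradient
        sy = dot(s, y)
        if sy > 1e-12                  # update only if the curvature condition holds
            ρ = 1 / sy
            H = (I - ρ * s * y') * H * (I - ρ * y * s') + ρ * s * s'
        end
        x, g = x_new, g_new
    end
    return x
end

# Convex quadratic f(x) = ½ xᵀAx - bᵀx, whose minimizer solves A x = b
A_demo = [4.0 1.0; 1.0 3.0]
b_demo = [1.0, 2.0]
quad(x) = 0.5 * dot(x, A_demo * x) - dot(b_demo, x)
quad_grad(x) = A_demo * x - b_demo

x_min = bfgs_sketch(quad, quad_grad, [0.0, 0.0])
println("BFGS: ", x_min, "  exact: ", A_demo \ b_demo)
```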
Visualizing the Results
Let us now plot the loss over time
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
lines!(ax, losses; linewidth=4, alpha=0.75)
scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)
fig
end
Finally, let us visualize the results
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(
compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
)
l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
s3 = scatter!(
ax,
tsteps,
waveform_nn_trained;
marker=:circle,
alpha=0.5,
strokewidth=2,
markersize=12,
)
axislegend(
ax,
[[l1, s1], [l2, s2], [l3, s3]],
["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net (Trained)"];
position=:lb,
)
fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.