Training a PINN on a 2D PDE
In this tutorial, we use a PINN to solve a 2D PDE. We use the same system as the NeuralPDE tutorials, but we write our own loss function and rely on the nested AD capabilities of Lux.jl.
This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to NeuralPDE.jl.
Package Imports
using Lux,
Optimisers,
Random,
Printf,
Statistics,
MLUtils,
OnlineStats,
CairoMakie,
Reactant,
Enzyme
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
Problem Definition
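Since Lux supports nested AD, we keep the second-order spatial derivatives in the loss instead of rewriting the PDE as a first-order system. These derivatives are computed by nesting Enzyme reverse-mode calls, and training then differentiates the resulting loss with respect to the network parameters.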
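Reading the residual in physics_informed_loss_function together with analytical_solution (both defined below), the problem is the 2D heat equation on (x, y, t) ∈ [0, 2]³, the range used for the training grid, with a known closed-form solution:

$$
\frac{\partial u}{\partial t} = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}, \qquad (x, y, t) \in [0, 2]^3, \qquad u(x, y, t) = e^{x+y}\cos(x + y + 4t)
$$

With s = x + y, the closed form satisfies the equation, since both sides equal −4e^s sin(s + 4t):

$$
\frac{\partial u}{\partial t} = -4e^{s}\sin(s + 4t), \qquad
\frac{\partial^2 u}{\partial x^2} = \frac{\partial^2 u}{\partial y^2} = -2e^{s}\sin(s + 4t)
$$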
Define the Neural Networks
The network takes 3 input variables (x, y, t) and outputs a scalar u(x, y, t). We wrap the Chain in a PINN layer, a subtype of AbstractLuxWrapperLayer{:model}, so that it forwards to the underlying model and can be trained with Training.TrainState.
struct PINN{M} <: AbstractLuxWrapperLayer{:model}
model::M
end
function PINN(; hidden_dims::Int=32)
return PINN(
Chain(
Dense(3 => hidden_dims, tanh),
Dense(hidden_dims => hidden_dims, tanh),
Dense(hidden_dims => hidden_dims, tanh),
Dense(hidden_dims => 1),
),
)
end
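As a quick check on model size, the parameter count of this architecture can be worked out by hand: each Dense layer contributes in × out weights plus out biases (assuming Lux's default use_bias=true), giving 2h² + 7h + 1 for hidden width h. The helpers below (dense_params, pinn_param_count) are hypothetical and only do this arithmetic; LuxCore.parameterlength on the actual model should report the same number.

```julia
# Hypothetical helper, not part of Lux: parameter count of one Dense layer
# with a bias vector (weights: out × in, bias: out).
dense_params(in_dims::Int, out_dims::Int) = in_dims * out_dims + out_dims

# Mirrors the Chain inside PINN: 3 => h, h => h, h => h, h => 1.
function pinn_param_count(hidden_dims::Int)
    layers = [
        (3, hidden_dims),
        (hidden_dims, hidden_dims),
        (hidden_dims, hidden_dims),
        (hidden_dims, 1),
    ]
    return sum(dense_params(i, o) for (i, o) in layers)
end

println(pinn_param_count(32))   # 2273, the default width of PINN()
println(pinn_param_count(128))  # 33665, the width used by train_model
```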
Define the Loss Functions
To build the loss, we first need derivatives of the network output with respect to its inputs. Each helper sums the output over the batch before differentiating; since every column of xyt is an independent sample, the gradient of that sum holds the per-sample derivatives.
# Rows of xyt are (x, y, t); columns are samples.
function ∂u_∂t(model::StatefulLuxLayer, xyt::AbstractArray)
return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][3, :]
end
function ∂u_∂x(model::StatefulLuxLayer, xyt::AbstractArray)
return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][1, :]
end
function ∂u_∂y(model::StatefulLuxLayer, xyt::AbstractArray)
return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][2, :]
end
# Second derivatives differentiate the first-derivative helpers once more (nested AD).
# The model is passed as Enzyme.Const, so the gradient w.r.t. xyt is element [2] of the result.
function ∂²u_∂x²(model::StatefulLuxLayer, xyt::AbstractArray)
return Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂x, Enzyme.Const(model), xyt)[2][1, :]
end
function ∂²u_∂y²(model::StatefulLuxLayer, xyt::AbstractArray)
return Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂y, Enzyme.Const(model), xyt)[2][2, :]
end
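Derivative helpers like these are easy to get subtly wrong (for example, by picking the wrong row of the gradient), so a cheap sanity check is to compare against central differences. The sketch below is plain Julia with no Enzyme; fd_partial and u_exact are hypothetical helpers, and it checks the finite differences against the closed-form ∂u/∂x and ∂u/∂t of the analytical solution. Since fd_partial only calls f(xyt), it should also accept a StatefulLuxLayer, but that use is an untested assumption here.

```julia
# Hypothetical helper (not part of Lux): central-difference approximation of the
# partial derivative of a column-wise function f w.r.t. row `dim` of a 3×N batch.
# Shifting row `dim` of every column at once still yields per-column partials,
# because no column's output depends on any other column.
function fd_partial(f, xyt::AbstractMatrix, dim::Int; h=1e-5)
    shift = zeros(eltype(xyt), size(xyt, 1))
    shift[dim] = h
    return (vec(f(xyt .+ shift)) .- vec(f(xyt .- shift))) ./ (2h)
end

# Same analytical solution as in the tutorial, for a batch with rows (x, y, t).
function u_exact(xyt::AbstractMatrix)
    x, y, t = xyt[1, :], xyt[2, :], xyt[3, :]
    return @. exp(x + y) * cos(x + y + 4t)
end

# A few fixed sample points (one per column), rows ordered (x, y, t).
pts = [0.1 0.5 1.0 1.7;
       0.3 0.2 0.9 1.1;
       0.0 0.4 0.8 1.5]
xs, ys, ts = pts[1, :], pts[2, :], pts[3, :]

# Closed-form first derivatives of u_exact.
ux_closed = @. exp(xs + ys) * (cos(xs + ys + 4ts) - sin(xs + ys + 4ts))
ut_closed = @. -4 * exp(xs + ys) * sin(xs + ys + 4ts)

err_x = maximum(abs, fd_partial(u_exact, pts, 1) .- ux_closed)
err_t = maximum(abs, fd_partial(u_exact, pts, 3) .- ut_closed)
println((err_x, err_t))  # both small: only finite-difference error remains
```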
The physics-informed loss is the mean squared PDE residual, ∂u/∂t − ∂²u/∂x² − ∂²u/∂y², over the collocation points:
function physics_informed_loss_function(model::StatefulLuxLayer, xyt::AbstractArray)
return mean(abs2, ∂u_∂t(model, xyt) .- ∂²u_∂x²(model, xyt) .- ∂²u_∂y²(model, xyt))
end
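To see that this residual actually singles out the right solution, the sketch below evaluates the same mean-of-squared-residuals reduction in closed form for a hypothetical family u_k(x, y, t) = exp(x + y) cos(x + y + kt). Its derivatives are ∂u/∂t = −k e^{x+y} sin(·) and ∂²u/∂x² = ∂²u/∂y² = −2 e^{x+y} sin(·), so the residual is (4 − k) e^{x+y} sin(x + y + kt): zero only for k = 4, the tutorial's analytical solution. The names residual and physics_loss are illustrative, not part of the tutorial.

```julia
using Statistics: mean

# Closed-form residual ∂u/∂t − ∂²u/∂x² − ∂²u/∂y² for the candidate
# u_k(x, y, t) = exp(x + y) * cos(x + y + k*t).
function residual(k, x, y, t)
    s = x + y
    u_t = -k * exp(s) * sin(s + k * t)
    u_xx = -2 * exp(s) * sin(s + k * t)
    u_yy = u_xx  # identical by the symmetry of u_k in x and y
    return u_t - u_xx - u_yy
end

# 16 points per axis on [0, 2], like the training grid.
sample_grid = range(0.0, 2.0; length=16)
sample_pts = collect(Iterators.product(sample_grid, sample_grid, sample_grid))

# Same reduction as physics_informed_loss_function: mean of squared residuals.
physics_loss(k) = mean(p -> abs2(residual(k, p...)), sample_pts)

println(physics_loss(4.0))  # exact solution: the residual vanishes
println(physics_loss(3.0))  # wrong time frequency: clearly nonzero
```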
We also need a mean-squared-error term against known values of u. The same function serves both the interior data points and the boundary conditions.
function mse_loss_function(
model::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray
)
return MSELoss()(model(xyt), target)
end
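mse_loss_function delegates to Lux's MSELoss. Assuming its default mean aggregation, it is the average of squared differences over all elements; the plain-Julia stand-in below (mse is a hypothetical name) makes that arithmetic explicit on a tiny example with the same 1×N shape as the model outputs and targets.

```julia
using Statistics: mean

# Hypothetical stand-in for MSELoss()(pred, truth), assuming mean aggregation:
# the average of squared differences over all elements.
mse(pred, truth) = mean(abs2, pred .- truth)

# 1×N row matrices, like the model outputs and targets in this tutorial.
pred = [1.0 2.0 3.0]
truth = [1.0 2.0 5.0]
println(mse(pred, truth))  # (0 + 0 + 4) / 3
```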
function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
smodel = StatefulLuxLayer(model, ps, st)
physics_loss = physics_informed_loss_function(smodel, xyt)
data_loss = mse_loss_function(smodel, target_data, xyt)
bc_loss = mse_loss_function(smodel, target_bc, xyt_bc)
loss = physics_loss + data_loss + bc_loss
return loss, smodel.st, (; physics_loss, data_loss, bc_loss)
end
Generate the Data
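We generate the training data from the analytical solution on a regular 16 × 16 × 16 grid over the domain (x, y, t) ∈ [0, 2]³, plus 512 points on each of the five boundaries t = 0, x = 0, x = 2, y = 0, and y = 2. The inputs are then min-max scaled to [0, 1], and all targets are scaled with a common min-max range.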
analytical_solution(x, y, t) = @. exp(x + y) * cos(x + y + 4t)
analytical_solution(xyt) = analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :])
begin
grid_len = 16
grid = range(0.0f0, 2.0f0; length=grid_len)
xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))])
target_data = reshape(analytical_solution(xyt), 1, :)
bc_len = 512
x = collect(range(0.0f0, 2.0f0; length=bc_len))
y = collect(range(0.0f0, 2.0f0; length=bc_len))
t = collect(range(0.0f0, 2.0f0; length=bc_len))
xyt_bc = hcat(
stack((x, y, zeros(Float32, bc_len)); dims=1),
stack((zeros(Float32, bc_len), y, t); dims=1),
stack((ones(Float32, bc_len) .* 2, y, t); dims=1),
stack((x, zeros(Float32, bc_len), t); dims=1),
stack((x, ones(Float32, bc_len) .* 2, t); dims=1),
)
target_bc = reshape(analytical_solution(xyt_bc), 1, :)
min_target_bc, max_target_bc = extrema(target_bc)
min_data, max_data = extrema(target_data)
min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc)
xyt = (xyt .- minimum(xyt)) ./ (maximum(xyt) .- minimum(xyt))
xyt_bc = (xyt_bc .- minimum(xyt_bc)) ./ (maximum(xyt_bc) .- minimum(xyt_bc))
target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end
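The min-max scaling at the end of this block maps every coordinate from [0, 2] to [0, 1] (all three coordinates share that range here, so the global minimum and maximum act per coordinate), i.e. x' = x/2, y' = y/2, t' = t/2. It maps the targets affinely, u' = (u − a)/b with a = min_pde_val and b = max_pde_val − min_pde_val. Because the physics residual is evaluated on the scaled inputs, it helps to see how the PDE transforms under the chain rule:

$$
\frac{\partial u'}{\partial t'} = \frac{2}{b}\,\frac{\partial u}{\partial t}, \qquad
\frac{\partial^2 u'}{\partial x'^2} = \frac{4}{b}\,\frac{\partial^2 u}{\partial x^2}, \qquad
\frac{\partial^2 u'}{\partial y'^2} = \frac{4}{b}\,\frac{\partial^2 u}{\partial y^2}
$$

$$
\frac{\partial u}{\partial t} = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}
\quad\Longrightarrow\quad
2\,\frac{\partial u'}{\partial t'} = \frac{\partial^2 u'}{\partial x'^2} + \frac{\partial^2 u'}{\partial y'^2}
$$

So, for the exact scaled solution, the residual ∂u/∂t − ∂²u/∂x² − ∂²u/∂y² minimized by physics_informed_loss_function equals −∂u'/∂t' rather than zero, which puts it in tension with the data and boundary terms. If exact consistency matters, one option (not used in this tutorial) is to weight the time derivative by 2 in the residual.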
Training
function train_model(
xyt,
target_data,
xyt_bc,
target_bc;
seed::Int=0,
maxiters::Int=50000,
hidden_dims::Int=128,
)
rng = Random.default_rng()
Random.seed!(rng, seed)
pinn = PINN(; hidden_dims)
ps, st = Lux.setup(rng, pinn) |> xdev
bc_dataloader =
DataLoader((xyt_bc, target_bc); batchsize=128, shuffle=true, partial=false) |> xdev
pde_dataloader =
DataLoader((xyt, target_data); batchsize=128, shuffle=true, partial=false) |> xdev
train_state = Training.TrainState(pinn, ps, st, Adam(0.005f0))
lr = i -> i < 5000 ? 0.005f0 : (i < 10000 ? 0.0005f0 : 0.00005f0)
total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
_ -> OnlineStats.CircBuff(Float32, 32; rev=true), 4
)
iter = 1
for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in
zip(Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader))
Optimisers.adjust!(train_state, lr(iter))
_, loss, stats, train_state = Training.single_train_step!(
AutoEnzyme(),
loss_function,
(xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
train_state;
return_gradients=Val(false),
)
fit!(total_loss_tracker, Float32(loss))
fit!(physics_loss_tracker, Float32(stats.physics_loss))
fit!(data_loss_tracker, Float32(stats.data_loss))
fit!(bc_loss_tracker, Float32(stats.bc_loss))
mean_loss = mean(OnlineStats.value(total_loss_tracker))
mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))
isnan(loss) && throw(ArgumentError("NaN Loss Detected"))
if iter % 1000 == 1 || iter == maxiters
@printf(
"Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
(%.9f) \t Data Loss: %.9f (%.9f) \t BC \
Loss: %.9f (%.9f)\n",
iter,
maxiters,
loss,
mean_loss,
stats.physics_loss,
mean_physics_loss,
stats.data_loss,
mean_data_loss,
stats.bc_loss,
mean_bc_loss
)
end
iter += 1
iter > maxiters && break
end
return StatefulLuxLayer(pinn, cdev(train_state.parameters), cdev(train_state.states))
end
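Two parts of the training loop are easy to check in isolation with plain Julia: the step-decay learning-rate schedule and how the two cycled data loaders pair up. The sketch below substitutes integer ranges for the DataLoaders. This assumes that shuffling changes which samples land in a batch but not how many batches there are, and that partial=false drops an incomplete final batch (here both counts divide evenly anyway). The names n_pde_batches, n_bc_batches, and pairs are illustrative.

```julia
# Step-decay learning-rate schedule, copied from train_model:
# 5e-3 before iteration 5000, 5e-4 before iteration 10000, then 5e-5.
lr = i -> i < 5000 ? 0.005f0 : (i < 10000 ? 0.0005f0 : 0.00005f0)

# Batches per epoch with batchsize=128 and partial=false:
# 16^3 = 4096 interior points and 5 * 512 = 2560 boundary points.
n_pde_batches = (16^3) ÷ 128
n_bc_batches = (5 * 512) ÷ 128

# Stand-ins for the two DataLoaders (batch indices only). Zipping two cycled
# iterators never runs out, so train_model must stop on its own counter.
pairs = collect(Iterators.take(
    zip(Iterators.cycle(1:n_pde_batches), Iterators.cycle(1:n_bc_batches)), 64))

println(lr.((1, 4999, 5000, 9999, 10000, 50000)))
println(pairs[21], " ", pairs[33])  # (21, 1) (1, 13): the loaders wrap around independently
```

Because the two loaders have different lengths, a given PDE batch meets a different BC batch in each epoch.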
trained_model = train_model(xyt, target_data, xyt_bc, target_bc)
Iteration: [ 1/ 50000] Loss: 20.523933411 (20.523933411) Physics Loss: 16.931318283 (16.931318283) Data Loss: 2.007483006 (2.007483006) BC Loss: 1.585133195 (1.585133195)
Iteration: [ 1001/ 50000] Loss: 0.017368603 (0.019241145) Physics Loss: 0.000384357 (0.000523635) Data Loss: 0.005318487 (0.007538572) BC Loss: 0.011665760 (0.011178939)
Iteration: [ 2001/ 50000] Loss: 0.015431654 (0.018665703) Physics Loss: 0.001248552 (0.001662074) Data Loss: 0.004322530 (0.006408241) BC Loss: 0.009860572 (0.010595390)
Iteration: [ 3001/ 50000] Loss: 0.015749799 (0.015216216) Physics Loss: 0.000569920 (0.001279173) Data Loss: 0.004014936 (0.004232483) BC Loss: 0.011164943 (0.009704558)
Iteration: [ 4001/ 50000] Loss: 0.009718472 (0.008712115) Physics Loss: 0.002387779 (0.003379629) Data Loss: 0.003175758 (0.002104492) BC Loss: 0.004154935 (0.003227993)
Iteration: [ 5001/ 50000] Loss: 0.004535029 (0.007336445) Physics Loss: 0.000547516 (0.003038818) Data Loss: 0.002077004 (0.001806144) BC Loss: 0.001910509 (0.002491483)
Iteration: [ 6001/ 50000] Loss: 0.001165972 (0.001445780) Physics Loss: 0.000292996 (0.000317503) Data Loss: 0.000641868 (0.000828786) BC Loss: 0.000231109 (0.000299491)
Iteration: [ 7001/ 50000] Loss: 0.001423379 (0.000975281) Physics Loss: 0.000278485 (0.000281843) Data Loss: 0.001017965 (0.000542995) BC Loss: 0.000126929 (0.000150443)
Iteration: [ 8001/ 50000] Loss: 0.001053529 (0.000919709) Physics Loss: 0.000653722 (0.000405850) Data Loss: 0.000317160 (0.000413969) BC Loss: 0.000082647 (0.000099890)
Iteration: [ 9001/ 50000] Loss: 0.002087245 (0.002705987) Physics Loss: 0.001064755 (0.001851798) Data Loss: 0.000757199 (0.000511168) BC Loss: 0.000265291 (0.000343021)
Iteration: [ 10001/ 50000] Loss: 0.000529433 (0.000656713) Physics Loss: 0.000160650 (0.000276859) Data Loss: 0.000314980 (0.000315967) BC Loss: 0.000053804 (0.000063888)
Iteration: [ 11001/ 50000] Loss: 0.000377613 (0.000412380) Physics Loss: 0.000133789 (0.000080756) Data Loss: 0.000195689 (0.000289428) BC Loss: 0.000048136 (0.000042197)
Iteration: [ 12001/ 50000] Loss: 0.000279888 (0.000380037) Physics Loss: 0.000054443 (0.000074833) Data Loss: 0.000180785 (0.000264915) BC Loss: 0.000044660 (0.000040290)
Iteration: [ 13001/ 50000] Loss: 0.000341679 (0.000357849) Physics Loss: 0.000085268 (0.000077957) Data Loss: 0.000220003 (0.000240304) BC Loss: 0.000036408 (0.000039588)
Iteration: [ 14001/ 50000] Loss: 0.000448218 (0.000360817) Physics Loss: 0.000102345 (0.000079283) Data Loss: 0.000306074 (0.000248709) BC Loss: 0.000039799 (0.000032825)
Iteration: [ 15001/ 50000] Loss: 0.000272533 (0.000311834) Physics Loss: 0.000066018 (0.000066296) Data Loss: 0.000172731 (0.000211339) BC Loss: 0.000033784 (0.000034200)
Iteration: [ 16001/ 50000] Loss: 0.000238120 (0.000315288) Physics Loss: 0.000058756 (0.000070502) Data Loss: 0.000143703 (0.000212839) BC Loss: 0.000035661 (0.000031947)
Iteration: [ 17001/ 50000] Loss: 0.000418589 (0.000316255) Physics Loss: 0.000073222 (0.000073601) Data Loss: 0.000320466 (0.000212101) BC Loss: 0.000024901 (0.000030552)
Iteration: [ 18001/ 50000] Loss: 0.000241269 (0.000309437) Physics Loss: 0.000064064 (0.000074983) Data Loss: 0.000143459 (0.000204661) BC Loss: 0.000033745 (0.000029792)
Iteration: [ 19001/ 50000] Loss: 0.000231196 (0.000295417) Physics Loss: 0.000068740 (0.000061112) Data Loss: 0.000142923 (0.000205838) BC Loss: 0.000019532 (0.000028467)
Iteration: [ 20001/ 50000] Loss: 0.000318833 (0.000268500) Physics Loss: 0.000059162 (0.000054753) Data Loss: 0.000238338 (0.000188364) BC Loss: 0.000021333 (0.000025382)
Iteration: [ 21001/ 50000] Loss: 0.000312148 (0.000262511) Physics Loss: 0.000056570 (0.000058834) Data Loss: 0.000227032 (0.000177772) BC Loss: 0.000028546 (0.000025905)
Iteration: [ 22001/ 50000] Loss: 0.000186620 (0.000260711) Physics Loss: 0.000042332 (0.000060141) Data Loss: 0.000111391 (0.000174661) BC Loss: 0.000032897 (0.000025909)
Iteration: [ 23001/ 50000] Loss: 0.000240173 (0.000268906) Physics Loss: 0.000040930 (0.000061117) Data Loss: 0.000177015 (0.000183045) BC Loss: 0.000022228 (0.000024744)
Iteration: [ 24001/ 50000] Loss: 0.000287513 (0.000258606) Physics Loss: 0.000036916 (0.000055313) Data Loss: 0.000225571 (0.000175950) BC Loss: 0.000025025 (0.000027343)
Iteration: [ 25001/ 50000] Loss: 0.000207260 (0.000233574) Physics Loss: 0.000039301 (0.000037695) Data Loss: 0.000145235 (0.000172560) BC Loss: 0.000022724 (0.000023318)
Iteration: [ 26001/ 50000] Loss: 0.000255382 (0.000263254) Physics Loss: 0.000078154 (0.000071206) Data Loss: 0.000156958 (0.000167393) BC Loss: 0.000020270 (0.000024656)
Iteration: [ 27001/ 50000] Loss: 0.000239542 (0.000242097) Physics Loss: 0.000055097 (0.000048702) Data Loss: 0.000158798 (0.000168599) BC Loss: 0.000025647 (0.000024796)
Iteration: [ 28001/ 50000] Loss: 0.000242247 (0.000227806) Physics Loss: 0.000068621 (0.000043346) Data Loss: 0.000152588 (0.000161294) BC Loss: 0.000021038 (0.000023166)
Iteration: [ 29001/ 50000] Loss: 0.000219872 (0.000234653) Physics Loss: 0.000039012 (0.000049347) Data Loss: 0.000146876 (0.000162229) BC Loss: 0.000033984 (0.000023077)
Iteration: [ 30001/ 50000] Loss: 0.000211663 (0.000234121) Physics Loss: 0.000026093 (0.000048308) Data Loss: 0.000159883 (0.000162785) BC Loss: 0.000025687 (0.000023027)
Iteration: [ 31001/ 50000] Loss: 0.000289982 (0.000228052) Physics Loss: 0.000058561 (0.000044936) Data Loss: 0.000211453 (0.000161257) BC Loss: 0.000019968 (0.000021859)
Iteration: [ 32001/ 50000] Loss: 0.000211890 (0.000215507) Physics Loss: 0.000045629 (0.000039202) Data Loss: 0.000144265 (0.000154199) BC Loss: 0.000021996 (0.000022106)
Iteration: [ 33001/ 50000] Loss: 0.000212471 (0.000214695) Physics Loss: 0.000053961 (0.000042916) Data Loss: 0.000133127 (0.000149925) BC Loss: 0.000025383 (0.000021854)
Iteration: [ 34001/ 50000] Loss: 0.000218986 (0.000201616) Physics Loss: 0.000049417 (0.000034055) Data Loss: 0.000144530 (0.000146131) BC Loss: 0.000025038 (0.000021430)
Iteration: [ 35001/ 50000] Loss: 0.000133406 (0.000208700) Physics Loss: 0.000015525 (0.000038143) Data Loss: 0.000093331 (0.000149673) BC Loss: 0.000024550 (0.000020884)
Iteration: [ 36001/ 50000] Loss: 0.000175085 (0.000198660) Physics Loss: 0.000038356 (0.000029753) Data Loss: 0.000117304 (0.000148736) BC Loss: 0.000019425 (0.000020171)
Iteration: [ 37001/ 50000] Loss: 0.000302884 (0.000203824) Physics Loss: 0.000054658 (0.000034690) Data Loss: 0.000225473 (0.000147490) BC Loss: 0.000022752 (0.000021644)
Iteration: [ 38001/ 50000] Loss: 0.000239741 (0.000200345) Physics Loss: 0.000029558 (0.000029087) Data Loss: 0.000186454 (0.000152546) BC Loss: 0.000023728 (0.000018712)
Iteration: [ 39001/ 50000] Loss: 0.000186373 (0.000205306) Physics Loss: 0.000037727 (0.000037231) Data Loss: 0.000130181 (0.000147349) BC Loss: 0.000018464 (0.000020726)
Iteration: [ 40001/ 50000] Loss: 0.000174471 (0.000202701) Physics Loss: 0.000023760 (0.000036654) Data Loss: 0.000129201 (0.000144902) BC Loss: 0.000021511 (0.000021145)
Iteration: [ 41001/ 50000] Loss: 0.000162912 (0.000196236) Physics Loss: 0.000021177 (0.000028328) Data Loss: 0.000120247 (0.000147842) BC Loss: 0.000021488 (0.000020066)
Iteration: [ 42001/ 50000] Loss: 0.000193533 (0.000197097) Physics Loss: 0.000033993 (0.000030959) Data Loss: 0.000140259 (0.000145902) BC Loss: 0.000019280 (0.000020236)
Iteration: [ 43001/ 50000] Loss: 0.000207839 (0.000208236) Physics Loss: 0.000036329 (0.000037497) Data Loss: 0.000152991 (0.000147151) BC Loss: 0.000018519 (0.000023588)
Iteration: [ 44001/ 50000] Loss: 0.000176791 (0.000191548) Physics Loss: 0.000018320 (0.000027500) Data Loss: 0.000142873 (0.000142562) BC Loss: 0.000015598 (0.000021485)
Iteration: [ 45001/ 50000] Loss: 0.000307369 (0.000208912) Physics Loss: 0.000066642 (0.000032655) Data Loss: 0.000212272 (0.000153139) BC Loss: 0.000028455 (0.000023119)
Iteration: [ 46001/ 50000] Loss: 0.000190369 (0.000189811) Physics Loss: 0.000017032 (0.000030931) Data Loss: 0.000154196 (0.000139556) BC Loss: 0.000019141 (0.000019324)
Iteration: [ 47001/ 50000] Loss: 0.000178279 (0.000192155) Physics Loss: 0.000021684 (0.000031024) Data Loss: 0.000137301 (0.000140530) BC Loss: 0.000019294 (0.000020602)
Iteration: [ 48001/ 50000] Loss: 0.000201516 (0.000205428) Physics Loss: 0.000036995 (0.000046254) Data Loss: 0.000144705 (0.000138921) BC Loss: 0.000019816 (0.000020253)
Iteration: [ 49001/ 50000] Loss: 0.000219198 (0.000222733) Physics Loss: 0.000071237 (0.000059399) Data Loss: 0.000126168 (0.000142836) BC Loss: 0.000021792 (0.000020498)
Visualizing the Results
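# Evaluation grid: x, y and t each span [0, 2]
ts, xs, ys = 0.0f0:0.05f0:2.0f0, 0.0f0:0.02f0:2.0f0, 0.0f0:0.02f0:2.0f0
# Each (x, y, t) point becomes one column of a 3 × N matrix
grid = stack([[elem...] for elem in vec(collect(Iterators.product(xs, ys, ts)))])
u_real = reshape(analytical_solution(grid), length(xs), length(ys), length(ts))
# Min-max scale the inputs to [0, 1]; one global min/max suffices because all three coordinates share the range [0, 2]
grid_normalized = (grid .- minimum(grid)) ./ (maximum(grid) .- minimum(grid))
u_pred = reshape(trained_model(grid_normalized), length(xs), length(ys), length(ts))
# Map the network output back to the original scale of the solution
u_pred = u_pred .* (max_pde_val - min_pde_val) .+ min_pde_val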
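The `reshape(..., length(xs), length(ys), length(ts))` calls above are only correct if the columns of `grid` are ordered with x varying fastest, then y, then t. The minimal sketch below checks that ordering on tiny stand-in ranges (the sizes are hypothetical, not the tutorial's), and also checks that the min-max formula used for `grid_normalized` can be inverted exactly. It assumes Julia 1.9 or newer, which is when `stack` was added to Base.

```julia
# Tiny stand-in ranges (hypothetical sizes) with the same roles as xs, ys, ts above
xs, ys, ts = 0.0f0:0.5f0:1.0f0, 0.0f0:1.0f0:1.0f0, 0.0f0:2.0f0:2.0f0

# Same construction as in the tutorial: each (x, y, t) tuple becomes one column
grid = stack([[elem...] for elem in vec(collect(Iterators.product(xs, ys, ts)))])

# 3 rows (x, y, t) and 3 * 2 * 2 = 12 columns, one per grid point
@assert size(grid) == (3, 12)

# Iterators.product varies its first argument fastest, which matches Julia's
# column-major layout, so reshaping a row puts grid point (i, j, l) at [i, j, l]
dims = (length(xs), length(ys), length(ts))
cube_x = reshape(grid[1, :], dims)
cube_y = reshape(grid[2, :], dims)
cube_t = reshape(grid[3, :], dims)
@assert all(cube_x[i, j, l] == xs[i] for i in eachindex(xs), j in eachindex(ys), l in eachindex(ts))
@assert all(cube_y[i, j, l] == ys[j] for i in eachindex(xs), j in eachindex(ys), l in eachindex(ts))
@assert all(cube_t[i, j, l] == ts[l] for i in eachindex(xs), j in eachindex(ys), l in eachindex(ts))

# Min-max scaling maps the data onto [0, 1], and the inverse map recovers it
lo, hi = minimum(grid), maximum(grid)
grid_normalized = (grid .- lo) ./ (hi - lo)
@assert extrema(grid_normalized) == (0.0f0, 1.0f0)
@assert grid_normalized .* (hi - lo) .+ lo ≈ grid
```

The same inverse formula, with `min_pde_val` and `max_pde_val` in place of `lo` and `hi`, is what the last line of the tutorial code applies to the network output.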
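begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in 1:length(ts)]
    # Use one color range for every frame so the colorbar matches each heatmap
    err_lims = extrema(stack(errs))
    Colorbar(fig[1, 2]; limits=err_lims)

    CairoMakie.record(fig, "pinn_nested_ad.gif", 1:length(ts); framerate=10) do i
        empty!(ax)  # clear the previous frame instead of stacking new plots on top of it
        ax.title = "Abs. Predictor Error | Time: $(ts[i])"
        err = errs[i]
        heatmap!(ax, xs, ys, err; colorrange=err_lims)
        # Draw the contours after the heatmap so they are not hidden underneath it
        contour!(ax, xs, ys, err; levels=10, linewidth=2, color=:white)
        return fig
    end

    fig
end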
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.