
Training a PINN on 2D PDE

In this tutorial we will train a Physics-Informed Neural Network (PINN) to solve a 2D PDE. We use the system from the NeuralPDE Tutorials, but with a custom loss function that exercises the nested AD capabilities of Lux.jl.

This is a demonstration of Lux.jl. For serious PINN use cases, please refer to NeuralPDE.jl.

Package Imports

julia
using Lux,
    Optimisers,
    Random,
    Printf,
    Statistics,
    MLUtils,
    OnlineStats,
    CairoMakie,
    Reactant,
    Enzyme

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Problem Definition

We solve the 2D heat equation ∂u/∂t = ∂²u/∂x² + ∂²u/∂y² on the domain x, y, t ∈ [0, 2], with the analytical solution u(x, y, t) = exp(x + y) cos(x + y + 4t) providing the data and boundary targets. Since Lux supports efficient nested AD up to 2nd order, we express the second-order spatial derivatives as nested first-order derivatives, so that the gradients of the loss can be computed with nested AD.
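As a quick sanity check (not part of the original tutorial), a few lines of plain Julia confirm numerically that this analytical solution satisfies the PDE; `u_exact` below is just a local helper name.

julia
# Minimal sketch: verify ∂u/∂t ≈ ∂²u/∂x² + ∂²u/∂y² for u = exp(x + y) * cos(x + y + 4t)
# at a sample point, using central finite differences.
u_exact(x, y, t) = exp(x + y) * cos(x + y + 4t)

let x = 0.7, y = 1.3, t = 0.4, h = 1e-4
    u_t  = (u_exact(x, y, t + h) - u_exact(x, y, t - h)) / (2h)
    u_xx = (u_exact(x + h, y, t) - 2u_exact(x, y, t) + u_exact(x - h, y, t)) / h^2
    u_yy = (u_exact(x, y + h, t) - 2u_exact(x, y, t) + u_exact(x, y - h, t)) / h^2
    @show u_t - (u_xx + u_yy)   # ≈ 0 up to discretization error
end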

Define the Neural Network

The network takes 3 input variables (x, y, t) and outputs a scalar value u(x, y, t). We define a thin wrapper over it so that we can train it with Training.TrainState.

julia
struct PINN{M} <: AbstractLuxWrapperLayer{:model}
    model::M
end

function PINN(; hidden_dims::Int=32)
    return PINN(
        Chain(
            Dense(3 => hidden_dims, tanh),
            Dense(hidden_dims => hidden_dims, tanh),
            Dense(hidden_dims => hidden_dims, tanh),
            Dense(hidden_dims => 1),
        ),
    )
end
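To make the shapes concrete, here is a small CPU smoke test (the `_demo` names are hypothetical, not part of the tutorial): points are stored column-wise as a 3 × N matrix, and the network returns a 1 × N matrix of predictions.

julia
# Minimal sketch (CPU): instantiate the wrapper and run a forward pass on 4 points.
rng = Random.default_rng()
pinn_demo = PINN(; hidden_dims=8)
ps_demo, st_demo = Lux.setup(rng, pinn_demo)

xyt_demo = rand(rng, Float32, 3, 4)   # each column is an (x, y, t) point
u_demo, _ = pinn_demo(xyt_demo, ps_demo, st_demo)
size(u_demo)                          # (1, 4)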

Define the Loss Functions

We will define a custom loss function that penalizes the PDE residual, which requires 2nd order derivatives of the network output. First, we define those derivatives of our model using Enzyme:

julia
function ∂u_∂t(model::StatefulLuxLayer, xyt::AbstractArray)
    # Reverse-mode gradient of sum(u) w.r.t. the inputs; row 3 of xyt is t
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][3, :]
end

function ∂u_∂x(model::StatefulLuxLayer, xyt::AbstractArray)
    # Row 1 of xyt is x
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][1, :]
end

function ∂u_∂y(model::StatefulLuxLayer, xyt::AbstractArray)
    # Row 2 of xyt is y
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][2, :]
end

function ∂²u_∂x²(model::StatefulLuxLayer, xyt::AbstractArray)
    # Nested AD: differentiate ∂u_∂x once more, holding the model constant
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂x, Enzyme.Const(model), xyt)[2][1, :]
end

function ∂²u_∂y²(model::StatefulLuxLayer, xyt::AbstractArray)
    # Nested AD: differentiate ∂u_∂y once more, holding the model constant
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂y, Enzyme.Const(model), xyt)[2][2, :]
end
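As an optional, rough check that the indexing above does what we expect, the sketch below compares ∂u_∂t against a central finite difference in t on a tiny untrained model. It assumes the definitions above and that plain-CPU Enzyme can differentiate this model outside the Reactant-compiled training loop; the `_chk` names are hypothetical.

julia
# Minimal sketch: finite-difference check of ∂u_∂t on the CPU.
rng = Random.default_rng()
pinn_chk = PINN(; hidden_dims=8)
ps_chk, st_chk = Lux.setup(rng, pinn_chk)
smodel_chk = StatefulLuxLayer(pinn_chk, ps_chk, st_chk)

xyt_chk = rand(rng, Float32, 3, 5)
h = 1.0f-3
xyt_p = copy(xyt_chk); xyt_p[3, :] .+= h
xyt_m = copy(xyt_chk); xyt_m[3, :] .-= h
fd_t = vec(smodel_chk(xyt_p) .- smodel_chk(xyt_m)) ./ (2h)

maximum(abs, ∂u_∂t(smodel_chk, xyt_chk) .- fd_t)   # should be small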

We will use the following physics-informed loss function, the mean squared residual of the PDE:

julia
function physics_informed_loss_function(model::StatefulLuxLayer, xyt::AbstractArray)
    return mean(abs2, ∂u_∂t(model, xyt) .- ∂²u_∂x²(model, xyt) .- ∂²u_∂y²(model, xyt))
end

Additionally, we need to compute a data-fitting loss, which is used both for the boundary conditions and for the sampled interior data.

julia
function mse_loss_function(
    model::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray
)
    return MSELoss()(model(xyt), target)
end

function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
    smodel = StatefulLuxLayer(model, ps, st)
    physics_loss = physics_informed_loss_function(smodel, xyt)
    data_loss = mse_loss_function(smodel, target_data, xyt)
    bc_loss = mse_loss_function(smodel, target_bc, xyt_bc)
    loss = physics_loss + data_loss + bc_loss
    return loss, smodel.st, (; physics_loss, data_loss, bc_loss)
end

Generate the Data

We will generate some data to train the model on, taken from the space-time domain x ∈ [0, 2], y ∈ [0, 2], t ∈ [0, 2]. Typically, you would want to be smarter about the sampling process, but for the sake of simplicity we use a regular grid for the interior points and evenly spaced points on the boundary.

julia
analytical_solution(x, y, t) = @. exp(x + y) * cos(x + y + 4t)
analytical_solution(xyt) = analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :])
julia
begin
    grid_len = 16

    grid = range(0.0f0, 2.0f0; length=grid_len)
    xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))])

    target_data = reshape(analytical_solution(xyt), 1, :)

    bc_len = 512

    x = collect(range(0.0f0, 2.0f0; length=bc_len))
    y = collect(range(0.0f0, 2.0f0; length=bc_len))
    t = collect(range(0.0f0, 2.0f0; length=bc_len))

    xyt_bc = hcat(
        stack((x, y, zeros(Float32, bc_len)); dims=1),
        stack((zeros(Float32, bc_len), y, t); dims=1),
        stack((ones(Float32, bc_len) .* 2, y, t); dims=1),
        stack((x, zeros(Float32, bc_len), t); dims=1),
        stack((x, ones(Float32, bc_len) .* 2, t); dims=1),
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc)

    xyt = (xyt .- minimum(xyt)) ./ (maximum(xyt) .- minimum(xyt))
    xyt_bc = (xyt_bc .- minimum(xyt_bc)) ./ (maximum(xyt_bc) .- minimum(xyt_bc))
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end
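A quick way to confirm what the block above produced (an optional check, not part of the original tutorial): the interior grid has 16³ points, the boundary set has 5 × 512 points, and all inputs and targets are normalized to [0, 1].

julia
# Minimal sketch: sanity-check shapes and normalization of the generated data.
@assert size(xyt) == (3, 16^3)
@assert size(target_data) == (1, 16^3)
@assert size(xyt_bc) == (3, 5 * 512)
@assert size(target_bc) == (1, 5 * 512)
@assert all(v -> 0 ≤ v ≤ 1, xyt) && all(v -> 0 ≤ v ≤ 1, target_data)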

Training

julia
function train_model(
    xyt,
    target_data,
    xyt_bc,
    target_bc;
    seed::Int=0,
    maxiters::Int=50000,
    hidden_dims::Int=128,
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> xdev

    bc_dataloader =
        DataLoader((xyt_bc, target_bc); batchsize=128, shuffle=true, partial=false) |> xdev
    pde_dataloader =
        DataLoader((xyt, target_data); batchsize=128, shuffle=true, partial=false) |> xdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.005f0))

    # Piecewise-constant learning rate schedule
    lr = i -> i < 5000 ? 0.005f0 : (i < 10000 ? 0.0005f0 : 0.00005f0)

    # Track the last 32 values of each loss component for smoothed reporting
    total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
        _ -> OnlineStats.CircBuff(Float32, 32; rev=true), 4
    )

    iter = 1
    for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in
        zip(Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader))
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state;
            return_gradients=Val(false),
        )

        fit!(total_loss_tracker, Float32(loss))
        fit!(physics_loss_tracker, Float32(stats.physics_loss))
        fit!(data_loss_tracker, Float32(stats.data_loss))
        fit!(bc_loss_tracker, Float32(stats.bc_loss))

        mean_loss = mean(OnlineStats.value(total_loss_tracker))
        mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
        mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
        mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))

        isnan(loss) && throw(ArgumentError("NaN Loss Detected"))

        if iter % 1000 == 1 || iter == maxiters
            @printf(
                "Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
                 (%.9f) \t Data Loss: %.9f (%.9f) \t BC \
                 Loss: %.9f (%.9f)\n",
                iter,
                maxiters,
                loss,
                mean_loss,
                stats.physics_loss,
                mean_physics_loss,
                stats.data_loss,
                mean_data_loss,
                stats.bc_loss,
                mean_bc_loss
            )
        end

        iter += 1
        iter ≥ maxiters && break
    end

    return StatefulLuxLayer(pinn, cdev(train_state.parameters), cdev(train_state.states))
end

trained_model = train_model(xyt, target_data, xyt_bc, target_bc)
Iteration: [     1/ 50000] 	 Loss: 20.523933411 (20.523933411) 	 Physics Loss: 16.931318283 (16.931318283) 	 Data Loss: 2.007483006 (2.007483006) 	 BC Loss: 1.585133195 (1.585133195)
Iteration: [  1001/ 50000] 	 Loss: 0.017368603 (0.019241145) 	 Physics Loss: 0.000384357 (0.000523635) 	 Data Loss: 0.005318487 (0.007538572) 	 BC Loss: 0.011665760 (0.011178939)
Iteration: [  2001/ 50000] 	 Loss: 0.015431654 (0.018665703) 	 Physics Loss: 0.001248552 (0.001662074) 	 Data Loss: 0.004322530 (0.006408241) 	 BC Loss: 0.009860572 (0.010595390)
Iteration: [  3001/ 50000] 	 Loss: 0.015749799 (0.015216216) 	 Physics Loss: 0.000569920 (0.001279173) 	 Data Loss: 0.004014936 (0.004232483) 	 BC Loss: 0.011164943 (0.009704558)
Iteration: [  4001/ 50000] 	 Loss: 0.009718472 (0.008712115) 	 Physics Loss: 0.002387779 (0.003379629) 	 Data Loss: 0.003175758 (0.002104492) 	 BC Loss: 0.004154935 (0.003227993)
Iteration: [  5001/ 50000] 	 Loss: 0.004535029 (0.007336445) 	 Physics Loss: 0.000547516 (0.003038818) 	 Data Loss: 0.002077004 (0.001806144) 	 BC Loss: 0.001910509 (0.002491483)
Iteration: [  6001/ 50000] 	 Loss: 0.001165972 (0.001445780) 	 Physics Loss: 0.000292996 (0.000317503) 	 Data Loss: 0.000641868 (0.000828786) 	 BC Loss: 0.000231109 (0.000299491)
Iteration: [  7001/ 50000] 	 Loss: 0.001423379 (0.000975281) 	 Physics Loss: 0.000278485 (0.000281843) 	 Data Loss: 0.001017965 (0.000542995) 	 BC Loss: 0.000126929 (0.000150443)
Iteration: [  8001/ 50000] 	 Loss: 0.001053529 (0.000919709) 	 Physics Loss: 0.000653722 (0.000405850) 	 Data Loss: 0.000317160 (0.000413969) 	 BC Loss: 0.000082647 (0.000099890)
Iteration: [  9001/ 50000] 	 Loss: 0.002087245 (0.002705987) 	 Physics Loss: 0.001064755 (0.001851798) 	 Data Loss: 0.000757199 (0.000511168) 	 BC Loss: 0.000265291 (0.000343021)
Iteration: [ 10001/ 50000] 	 Loss: 0.000529433 (0.000656713) 	 Physics Loss: 0.000160650 (0.000276859) 	 Data Loss: 0.000314980 (0.000315967) 	 BC Loss: 0.000053804 (0.000063888)
Iteration: [ 11001/ 50000] 	 Loss: 0.000377613 (0.000412380) 	 Physics Loss: 0.000133789 (0.000080756) 	 Data Loss: 0.000195689 (0.000289428) 	 BC Loss: 0.000048136 (0.000042197)
Iteration: [ 12001/ 50000] 	 Loss: 0.000279888 (0.000380037) 	 Physics Loss: 0.000054443 (0.000074833) 	 Data Loss: 0.000180785 (0.000264915) 	 BC Loss: 0.000044660 (0.000040290)
Iteration: [ 13001/ 50000] 	 Loss: 0.000341679 (0.000357849) 	 Physics Loss: 0.000085268 (0.000077957) 	 Data Loss: 0.000220003 (0.000240304) 	 BC Loss: 0.000036408 (0.000039588)
Iteration: [ 14001/ 50000] 	 Loss: 0.000448218 (0.000360817) 	 Physics Loss: 0.000102345 (0.000079283) 	 Data Loss: 0.000306074 (0.000248709) 	 BC Loss: 0.000039799 (0.000032825)
Iteration: [ 15001/ 50000] 	 Loss: 0.000272533 (0.000311834) 	 Physics Loss: 0.000066018 (0.000066296) 	 Data Loss: 0.000172731 (0.000211339) 	 BC Loss: 0.000033784 (0.000034200)
Iteration: [ 16001/ 50000] 	 Loss: 0.000238120 (0.000315288) 	 Physics Loss: 0.000058756 (0.000070502) 	 Data Loss: 0.000143703 (0.000212839) 	 BC Loss: 0.000035661 (0.000031947)
Iteration: [ 17001/ 50000] 	 Loss: 0.000418589 (0.000316255) 	 Physics Loss: 0.000073222 (0.000073601) 	 Data Loss: 0.000320466 (0.000212101) 	 BC Loss: 0.000024901 (0.000030552)
Iteration: [ 18001/ 50000] 	 Loss: 0.000241269 (0.000309437) 	 Physics Loss: 0.000064064 (0.000074983) 	 Data Loss: 0.000143459 (0.000204661) 	 BC Loss: 0.000033745 (0.000029792)
Iteration: [ 19001/ 50000] 	 Loss: 0.000231196 (0.000295417) 	 Physics Loss: 0.000068740 (0.000061112) 	 Data Loss: 0.000142923 (0.000205838) 	 BC Loss: 0.000019532 (0.000028467)
Iteration: [ 20001/ 50000] 	 Loss: 0.000318833 (0.000268500) 	 Physics Loss: 0.000059162 (0.000054753) 	 Data Loss: 0.000238338 (0.000188364) 	 BC Loss: 0.000021333 (0.000025382)
Iteration: [ 21001/ 50000] 	 Loss: 0.000312148 (0.000262511) 	 Physics Loss: 0.000056570 (0.000058834) 	 Data Loss: 0.000227032 (0.000177772) 	 BC Loss: 0.000028546 (0.000025905)
Iteration: [ 22001/ 50000] 	 Loss: 0.000186620 (0.000260711) 	 Physics Loss: 0.000042332 (0.000060141) 	 Data Loss: 0.000111391 (0.000174661) 	 BC Loss: 0.000032897 (0.000025909)
Iteration: [ 23001/ 50000] 	 Loss: 0.000240173 (0.000268906) 	 Physics Loss: 0.000040930 (0.000061117) 	 Data Loss: 0.000177015 (0.000183045) 	 BC Loss: 0.000022228 (0.000024744)
Iteration: [ 24001/ 50000] 	 Loss: 0.000287513 (0.000258606) 	 Physics Loss: 0.000036916 (0.000055313) 	 Data Loss: 0.000225571 (0.000175950) 	 BC Loss: 0.000025025 (0.000027343)
Iteration: [ 25001/ 50000] 	 Loss: 0.000207260 (0.000233574) 	 Physics Loss: 0.000039301 (0.000037695) 	 Data Loss: 0.000145235 (0.000172560) 	 BC Loss: 0.000022724 (0.000023318)
Iteration: [ 26001/ 50000] 	 Loss: 0.000255382 (0.000263254) 	 Physics Loss: 0.000078154 (0.000071206) 	 Data Loss: 0.000156958 (0.000167393) 	 BC Loss: 0.000020270 (0.000024656)
Iteration: [ 27001/ 50000] 	 Loss: 0.000239542 (0.000242097) 	 Physics Loss: 0.000055097 (0.000048702) 	 Data Loss: 0.000158798 (0.000168599) 	 BC Loss: 0.000025647 (0.000024796)
Iteration: [ 28001/ 50000] 	 Loss: 0.000242247 (0.000227806) 	 Physics Loss: 0.000068621 (0.000043346) 	 Data Loss: 0.000152588 (0.000161294) 	 BC Loss: 0.000021038 (0.000023166)
Iteration: [ 29001/ 50000] 	 Loss: 0.000219872 (0.000234653) 	 Physics Loss: 0.000039012 (0.000049347) 	 Data Loss: 0.000146876 (0.000162229) 	 BC Loss: 0.000033984 (0.000023077)
Iteration: [ 30001/ 50000] 	 Loss: 0.000211663 (0.000234121) 	 Physics Loss: 0.000026093 (0.000048308) 	 Data Loss: 0.000159883 (0.000162785) 	 BC Loss: 0.000025687 (0.000023027)
Iteration: [ 31001/ 50000] 	 Loss: 0.000289982 (0.000228052) 	 Physics Loss: 0.000058561 (0.000044936) 	 Data Loss: 0.000211453 (0.000161257) 	 BC Loss: 0.000019968 (0.000021859)
Iteration: [ 32001/ 50000] 	 Loss: 0.000211890 (0.000215507) 	 Physics Loss: 0.000045629 (0.000039202) 	 Data Loss: 0.000144265 (0.000154199) 	 BC Loss: 0.000021996 (0.000022106)
Iteration: [ 33001/ 50000] 	 Loss: 0.000212471 (0.000214695) 	 Physics Loss: 0.000053961 (0.000042916) 	 Data Loss: 0.000133127 (0.000149925) 	 BC Loss: 0.000025383 (0.000021854)
Iteration: [ 34001/ 50000] 	 Loss: 0.000218986 (0.000201616) 	 Physics Loss: 0.000049417 (0.000034055) 	 Data Loss: 0.000144530 (0.000146131) 	 BC Loss: 0.000025038 (0.000021430)
Iteration: [ 35001/ 50000] 	 Loss: 0.000133406 (0.000208700) 	 Physics Loss: 0.000015525 (0.000038143) 	 Data Loss: 0.000093331 (0.000149673) 	 BC Loss: 0.000024550 (0.000020884)
Iteration: [ 36001/ 50000] 	 Loss: 0.000175085 (0.000198660) 	 Physics Loss: 0.000038356 (0.000029753) 	 Data Loss: 0.000117304 (0.000148736) 	 BC Loss: 0.000019425 (0.000020171)
Iteration: [ 37001/ 50000] 	 Loss: 0.000302884 (0.000203824) 	 Physics Loss: 0.000054658 (0.000034690) 	 Data Loss: 0.000225473 (0.000147490) 	 BC Loss: 0.000022752 (0.000021644)
Iteration: [ 38001/ 50000] 	 Loss: 0.000239741 (0.000200345) 	 Physics Loss: 0.000029558 (0.000029087) 	 Data Loss: 0.000186454 (0.000152546) 	 BC Loss: 0.000023728 (0.000018712)
Iteration: [ 39001/ 50000] 	 Loss: 0.000186373 (0.000205306) 	 Physics Loss: 0.000037727 (0.000037231) 	 Data Loss: 0.000130181 (0.000147349) 	 BC Loss: 0.000018464 (0.000020726)
Iteration: [ 40001/ 50000] 	 Loss: 0.000174471 (0.000202701) 	 Physics Loss: 0.000023760 (0.000036654) 	 Data Loss: 0.000129201 (0.000144902) 	 BC Loss: 0.000021511 (0.000021145)
Iteration: [ 41001/ 50000] 	 Loss: 0.000162912 (0.000196236) 	 Physics Loss: 0.000021177 (0.000028328) 	 Data Loss: 0.000120247 (0.000147842) 	 BC Loss: 0.000021488 (0.000020066)
Iteration: [ 42001/ 50000] 	 Loss: 0.000193533 (0.000197097) 	 Physics Loss: 0.000033993 (0.000030959) 	 Data Loss: 0.000140259 (0.000145902) 	 BC Loss: 0.000019280 (0.000020236)
Iteration: [ 43001/ 50000] 	 Loss: 0.000207839 (0.000208236) 	 Physics Loss: 0.000036329 (0.000037497) 	 Data Loss: 0.000152991 (0.000147151) 	 BC Loss: 0.000018519 (0.000023588)
Iteration: [ 44001/ 50000] 	 Loss: 0.000176791 (0.000191548) 	 Physics Loss: 0.000018320 (0.000027500) 	 Data Loss: 0.000142873 (0.000142562) 	 BC Loss: 0.000015598 (0.000021485)
Iteration: [ 45001/ 50000] 	 Loss: 0.000307369 (0.000208912) 	 Physics Loss: 0.000066642 (0.000032655) 	 Data Loss: 0.000212272 (0.000153139) 	 BC Loss: 0.000028455 (0.000023119)
Iteration: [ 46001/ 50000] 	 Loss: 0.000190369 (0.000189811) 	 Physics Loss: 0.000017032 (0.000030931) 	 Data Loss: 0.000154196 (0.000139556) 	 BC Loss: 0.000019141 (0.000019324)
Iteration: [ 47001/ 50000] 	 Loss: 0.000178279 (0.000192155) 	 Physics Loss: 0.000021684 (0.000031024) 	 Data Loss: 0.000137301 (0.000140530) 	 BC Loss: 0.000019294 (0.000020602)
Iteration: [ 48001/ 50000] 	 Loss: 0.000201516 (0.000205428) 	 Physics Loss: 0.000036995 (0.000046254) 	 Data Loss: 0.000144705 (0.000138921) 	 BC Loss: 0.000019816 (0.000020253)
Iteration: [ 49001/ 50000] 	 Loss: 0.000219198 (0.000222733) 	 Physics Loss: 0.000071237 (0.000059399) 	 Data Loss: 0.000126168 (0.000142836) 	 BC Loss: 0.000021792 (0.000020498)

Visualizing the Results

julia
ts, xs, ys = 0.0f0:0.05f0:2.0f0, 0.0f0:0.02f0:2.0f0, 0.0f0:0.02f0:2.0f0
grid = stack([[elem...] for elem in vec(collect(Iterators.product(xs, ys, ts)))])

u_real = reshape(analytical_solution(grid), length(xs), length(ys), length(ts))

grid_normalized = (grid .- minimum(grid)) ./ (maximum(grid) .- minimum(grid))
u_pred = reshape(trained_model(grid_normalized), length(xs), length(ys), length(ts))
u_pred = u_pred .* (max_pde_val - min_pde_val) .+ min_pde_val

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in 1:length(ts)]
    Colorbar(fig[1, 2]; limits=extrema(stack(errs)))

    CairoMakie.record(fig, "pinn_nested_ad.gif", 1:length(ts); framerate=10) do i
        ax.title = "Abs. Predictor Error | Time: $(ts[i])"
        err = errs[i]
        contour!(ax, xs, ys, err; levels=10, linewidth=2)
        heatmap!(ax, xs, ys, err)
        return fig
    end

    fig
end
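Beyond the animation, a single scalar summary can be useful; for example, a relative RMS error of the prediction over the whole evaluation grid (a hypothetical addition, using the u_pred and u_real arrays from above).

julia
# Minimal sketch: relative RMS error of the prediction vs. the analytical solution.
rel_rms_error = sqrt(mean(abs2, u_pred .- u_real)) / sqrt(mean(abs2, u_real))
@printf "Relative RMS error: %.4e\n" rel_rms_error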

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.