Training a PINN on 2D PDE
In this tutorial we will go over using a PINN to solve a 2D PDE. We will use the system from the NeuralPDE tutorials; however, we will write our own loss function and use the nested AD capabilities of Lux.jl.
This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the package: NeuralPDE.jl.
Package Imports
using Lux,
Optimisers,
Random,
Printf,
Statistics,
MLUtils,
OnlineStats,
CairoMakie,
Reactant,
Enzyme
# Device handles used throughout the tutorial: data and parameters are piped through `xdev`
# before training and through `cdev` when the trained model is brought back.
# NOTE(review): `force=true` presumably makes this error out when no functional Reactant
# device is available instead of falling back silently — confirm against the MLDataDevices docs.
const xdev = reactant_device(; force=true)
const cdev = cpu_device()

# Problem Definition
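Since Lux supports efficient nested AD up to 2nd order, we will rewrite the problem with first order derivatives, so that we can compute the gradients of the loss using 2nd order AD. The PDE enforced by the physics loss below is $\frac{\partial u}{\partial t} = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}$, and the reference solution used to generate the data is $u(x, y, t) = e^{x + y} \cos(x + y + 4t)$.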
Define the Neural Networks
The network takes 3 input variables (x, y, t) and outputs a scalar value. Here, we will define a wrapper over the network, so that we can train it using Training.TrainState.
"""
    PINN{M}

Wrapper layer holding a single Lux model in its `model` field.

It subtypes `AbstractLuxWrapperLayer{:model}`, which names `model` as the wrapped field
(see the Lux documentation of `AbstractLuxWrapperLayer` for the exact forwarding rules).
"""
struct PINN{M} <: AbstractLuxWrapperLayer{:model}
    model::M
end
"""
    PINN(; hidden_dims::Int=32)

Build a `PINN` wrapping a fully connected network with 3 inputs, three `tanh` hidden
layers of width `hidden_dims`, and a single linear output.

# Keywords
- `hidden_dims`: width of every hidden layer; must be positive.

# Throws
- `ArgumentError` if `hidden_dims` is not positive.
"""
function PINN(; hidden_dims::Int=32)
    hidden_dims > 0 ||
        throw(ArgumentError("hidden_dims must be positive, got $hidden_dims"))
    mlp = Chain(
        Dense(3 => hidden_dims, tanh),
        Dense(hidden_dims => hidden_dims, tanh),
        Dense(hidden_dims => hidden_dims, tanh),
        Dense(hidden_dims => 1),
    )
    return PINN(mlp)
end

# Define the Loss Functions
We will define a custom loss function to compute the loss using 2nd order AD. For that, first we'll need to define the derivatives of our model:
"""
    ∂u_∂t(model, xyt)

Return row 3 of the reverse-mode gradient of `sum ∘ model` with respect to `xyt`,
i.e. the derivative of the summed model output with respect to the third input row.
"""
function ∂u_∂t(model::StatefulLuxLayer, xyt::AbstractArray)
    ∇u = first(Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt))
    return ∇u[3, :]
end
"""
    ∂u_∂x(model, xyt)

Return row 1 of the reverse-mode gradient of `sum ∘ model` with respect to `xyt`,
i.e. the derivative of the summed model output with respect to the first input row.
"""
function ∂u_∂x(model::StatefulLuxLayer, xyt::AbstractArray)
    ∇u = first(Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt))
    return ∇u[1, :]
end
"""
    ∂u_∂y(model, xyt)

Return row 2 of the reverse-mode gradient of `sum ∘ model` with respect to `xyt`,
i.e. the derivative of the summed model output with respect to the second input row.
"""
function ∂u_∂y(model::StatefulLuxLayer, xyt::AbstractArray)
    ∇u = first(Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt))
    return ∇u[2, :]
end
"""
    ∂²u_∂x²(model, xyt)

Second derivative along the first input row, obtained by differentiating
`sum ∘ ∂u_∂x` once more with respect to `xyt` (the model is passed as `Enzyme.Const`).
"""
function ∂²u_∂x²(model::StatefulLuxLayer, xyt::AbstractArray)
    grads = Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂x, Enzyme.Const(model), xyt)
    # The tuple has one entry per argument; the last one belongs to `xyt`.
    ∇ux = last(grads)
    return ∇ux[1, :]
end
"""
    ∂²u_∂y²(model, xyt)

Second derivative along the second input row, obtained by differentiating
`sum ∘ ∂u_∂y` once more with respect to `xyt` (the model is passed as `Enzyme.Const`).
"""
function ∂²u_∂y²(model::StatefulLuxLayer, xyt::AbstractArray)
    grads = Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂y, Enzyme.Const(model), xyt)
    # The tuple has one entry per argument; the last one belongs to `xyt`.
    ∇uy = last(grads)
    return ∇uy[2, :]
end

# We will use the following loss function
"""
    physics_informed_loss_function(model, xyt)

Mean squared PDE residual `∂u/∂t - ∂²u/∂x² - ∂²u/∂y²` of `model` over the columns of `xyt`.
"""
function physics_informed_loss_function(model::StatefulLuxLayer, xyt::AbstractArray)
    # NOTE(review): derivatives are taken with respect to the coordinates stored in `xyt`.
    # The data-generation code rescales the coordinates before training, so this residual
    # is evaluated in the rescaled coordinates — confirm the unit coefficients are intended.
    residual = ∂u_∂t(model, xyt) .- ∂²u_∂x²(model, xyt) .- ∂²u_∂y²(model, xyt)
    return mean(abs2, residual)
end

# Additionally, we need to compute the loss with respect to the boundary conditions.
"""
    mse_loss_function(model, target, xyt)

Evaluate `model` on `xyt` and return the `MSELoss` between the prediction and `target`.
"""
function mse_loss_function(
    model::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray
)
    prediction = model(xyt)
    criterion = MSELoss()
    return criterion(prediction, target)
end
"""
    loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))

Total training objective: physics residual on `xyt`, plus MSE against `target_data` on
`xyt`, plus MSE against `target_bc` on `xyt_bc`.

Returns `(loss, updated_state, stats)`, where `stats` is a named tuple carrying the three
individual terms (`physics_loss`, `data_loss`, `bc_loss`).
"""
function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
    # Bundle parameters and state so the helper losses can call the model like a function.
    smodel = StatefulLuxLayer(model, ps, st)
    physics_loss = physics_informed_loss_function(smodel, xyt)
    data_loss = mse_loss_function(smodel, target_data, xyt)
    bc_loss = mse_loss_function(smodel, target_bc, xyt_bc)
    loss = physics_loss + data_loss + bc_loss
    return loss, smodel.st, (; physics_loss, data_loss, bc_loss)
end

# Generate the Data
We will generate the training data from the analytical solution. We sample on a square spatial and temporal domain $x, y, t \in [0, 2]$, and then rescale both the inputs and the targets to $[0, 1]$.
# Closed-form reference solution u(x, y, t) = exp(x + y) * cos(x + y + 4t), evaluated
# elementwise so that scalars and equally-shaped arrays are both accepted.
function analytical_solution(x, y, t)
    return exp.(x .+ y) .* cos.(x .+ y .+ 4 .* t)
end
# Convenience method: `xyt` holds one sample per column, with rows 1, 2, 3 passed on as
# x, y, t respectively.
analytical_solution(xyt) = analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :])

begin
    # Interior samples: every combination of `grid_len` values per coordinate on [0, 2],
    # stacked so that each column is one (x, y, t) point.
    grid_len = 16
    grid = range(0.0f0, 2.0f0; length=grid_len)
    xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))])
    target_data = reshape(analytical_solution(xyt), 1, :)

    # Boundary / initial-condition samples: five groups of `bc_len` columns each.
    bc_len = 512
    x = collect(range(0.0f0, 2.0f0; length=bc_len))
    y = collect(range(0.0f0, 2.0f0; length=bc_len))
    t = collect(range(0.0f0, 2.0f0; length=bc_len))
    # NOTE(review): `x`, `y` and `t` are the same range, so within each group the two free
    # coordinates vary together (e.g. x == y on the t = 0 group) rather than covering the
    # whole face — confirm this sampling is intended.
    xyt_bc = hcat(
        stack((x, y, zeros(Float32, bc_len)); dims=1),        # t = 0
        stack((zeros(Float32, bc_len), y, t); dims=1),        # x = 0
        stack((ones(Float32, bc_len) .* 2, y, t); dims=1),    # x = 2
        stack((x, zeros(Float32, bc_len), t); dims=1),        # y = 0
        stack((x, ones(Float32, bc_len) .* 2, t); dims=1),    # y = 2
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    # Shared target range, so interior and boundary targets use the same affine rescaling
    # (it is reused later to map predictions back to the original scale).
    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc)

    # Min-max rescale inputs (using the global extrema over all three rows) and targets.
    xyt = (xyt .- minimum(xyt)) ./ (maximum(xyt) .- minimum(xyt))
    xyt_bc = (xyt_bc .- minimum(xyt_bc)) ./ (maximum(xyt_bc) .- minimum(xyt_bc))
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end

# Training
"""
    train_model(xyt, target_data, xyt_bc, target_bc;
                seed=0, maxiters=50000, hidden_dims=128)

Train a `PINN` on interior samples (`xyt`, `target_data`) and boundary samples
(`xyt_bc`, `target_bc`) and return the trained model on the CPU device.

# Keywords
- `seed`: value used to seed the default RNG before parameter initialisation.
- `maxiters`: number of optimisation steps; exactly this many steps are executed.
- `hidden_dims`: hidden-layer width forwarded to `PINN`.

# Returns
A `StatefulLuxLayer` wrapping the network with its trained parameters and states.

# Throws
- `ArgumentError("NaN Loss Detected")` as soon as a step produces a NaN loss.
"""
function train_model(
    xyt,
    target_data,
    xyt_bc,
    target_bc;
    seed::Int=0,
    maxiters::Int=50000,
    hidden_dims::Int=128,
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> xdev

    bc_dataloader =
        DataLoader((xyt_bc, target_bc); batchsize=128, shuffle=true, partial=false) |> xdev
    pde_dataloader =
        DataLoader((xyt, target_data); batchsize=128, shuffle=true, partial=false) |> xdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.005f0))
    # Piecewise-constant learning-rate schedule, indexed by the 1-based iteration number.
    lr = i -> i < 5000 ? 0.005f0 : (i < 10000 ? 0.0005f0 : 0.00005f0)

    # Circular buffers holding the most recent 32 values of each loss term.
    total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
        _ -> OnlineStats.CircBuff(Float32, 32; rev=true), 4
    )

    # Both loaders are cycled indefinitely; `take` bounds the loop to exactly `maxiters`
    # steps. (The previous manual counter broke out one step early, so the final
    # `iter == maxiters` report could never be printed.)
    batches = zip(Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader))
    for (iter, ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch))) in
        enumerate(Iterators.take(batches, maxiters))
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state;
            return_gradients=Val(false),
        )

        # Fail fast, before a NaN is recorded in the running statistics.
        isnan(loss) && throw(ArgumentError("NaN Loss Detected"))

        fit!(total_loss_tracker, Float32(loss))
        fit!(physics_loss_tracker, Float32(stats.physics_loss))
        fit!(data_loss_tracker, Float32(stats.data_loss))
        fit!(bc_loss_tracker, Float32(stats.bc_loss))

        if iter % 1000 == 1 || iter == maxiters
            # Running means are only needed for the report, so compute them here.
            mean_loss = mean(OnlineStats.value(total_loss_tracker))
            mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
            mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
            mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))

            @printf(
                "Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
                 (%.9f) \t Data Loss: %.9f (%.9f) \t BC \
                 Loss: %.9f (%.9f)\n",
                iter,
                maxiters,
                loss,
                mean_loss,
                stats.physics_loss,
                mean_physics_loss,
                stats.data_loss,
                mean_data_loss,
                stats.bc_loss,
                mean_bc_loss
            )
        end
    end

    return StatefulLuxLayer(pinn, cdev(train_state.parameters), cdev(train_state.states))
end
trained_model = train_model(xyt, target_data, xyt_bc, target_bc)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760849837.016775 103616 service.cc:158] XLA service 0x188dd590 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760849837.016840 103616 service.cc:166] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1760849837.017728 103616 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760849837.017763 103616 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1760849837.017809 103616 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760849837.029789 103616 cuda_dnn.cc:463] Loaded cuDNN version 91200
Iteration: [ 1/ 50000] Loss: 20.529172897 (20.529172897) Physics Loss: 16.935272217 (16.935272217) Data Loss: 2.008198738 (2.008198738) BC Loss: 1.585703492 (1.585703492)
Iteration: [ 1001/ 50000] Loss: 0.017331716 (0.019193068) Physics Loss: 0.000387317 (0.000485731) Data Loss: 0.005293047 (0.007528459) BC Loss: 0.011651352 (0.011178879)
Iteration: [ 2001/ 50000] Loss: 0.015900077 (0.018633004) Physics Loss: 0.001465373 (0.001504831) Data Loss: 0.004471334 (0.006471491) BC Loss: 0.009963371 (0.010656682)
Iteration: [ 3001/ 50000] Loss: 0.015973251 (0.015203977) Physics Loss: 0.000725899 (0.001278740) Data Loss: 0.004123617 (0.004273000) BC Loss: 0.011123735 (0.009652236)
Iteration: [ 4001/ 50000] Loss: 0.012695793 (0.012816982) Physics Loss: 0.004452500 (0.005652803) Data Loss: 0.002984334 (0.002828420) BC Loss: 0.005258959 (0.004335761)
Iteration: [ 5001/ 50000] Loss: 0.004897796 (0.008685599) Physics Loss: 0.001102662 (0.003582127) Data Loss: 0.002188099 (0.002333902) BC Loss: 0.001607035 (0.002769568)
Iteration: [ 6001/ 50000] Loss: 0.001092522 (0.001332408) Physics Loss: 0.000249394 (0.000307639) Data Loss: 0.000594520 (0.000753054) BC Loss: 0.000248608 (0.000271715)
Iteration: [ 7001/ 50000] Loss: 0.001504518 (0.001043316) Physics Loss: 0.000465222 (0.000412990) Data Loss: 0.000936627 (0.000496557) BC Loss: 0.000102668 (0.000133769)
Iteration: [ 8001/ 50000] Loss: 0.000930153 (0.001334653) Physics Loss: 0.000537639 (0.000803167) Data Loss: 0.000305875 (0.000412027) BC Loss: 0.000086639 (0.000119459)
Iteration: [ 9001/ 50000] Loss: 0.001518591 (0.001497649) Physics Loss: 0.000606798 (0.000906332) Data Loss: 0.000701658 (0.000386005) BC Loss: 0.000210135 (0.000205312)
Iteration: [ 10001/ 50000] Loss: 0.001146152 (0.001166578) Physics Loss: 0.000730612 (0.000714821) Data Loss: 0.000294488 (0.000349879) BC Loss: 0.000121052 (0.000101878)
Iteration: [ 11001/ 50000] Loss: 0.000303401 (0.000381006) Physics Loss: 0.000068595 (0.000060912) Data Loss: 0.000198215 (0.000283848) BC Loss: 0.000036590 (0.000036246)
Iteration: [ 12001/ 50000] Loss: 0.000254626 (0.000349781) Physics Loss: 0.000046110 (0.000057781) Data Loss: 0.000169172 (0.000256495) BC Loss: 0.000039345 (0.000035506)
Iteration: [ 13001/ 50000] Loss: 0.000286091 (0.000333204) Physics Loss: 0.000043883 (0.000066703) Data Loss: 0.000211084 (0.000233484) BC Loss: 0.000031124 (0.000033017)
Iteration: [ 14001/ 50000] Loss: 0.000349683 (0.000337811) Physics Loss: 0.000039542 (0.000065278) Data Loss: 0.000278752 (0.000243441) BC Loss: 0.000031389 (0.000029092)
Iteration: [ 15001/ 50000] Loss: 0.000285819 (0.000308746) Physics Loss: 0.000090523 (0.000073466) Data Loss: 0.000161379 (0.000204055) BC Loss: 0.000033917 (0.000031224)
Iteration: [ 16001/ 50000] Loss: 0.000209424 (0.000295532) Physics Loss: 0.000045663 (0.000058594) Data Loss: 0.000136078 (0.000207966) BC Loss: 0.000027683 (0.000028972)
Iteration: [ 17001/ 50000] Loss: 0.000402348 (0.000296749) Physics Loss: 0.000067907 (0.000061294) Data Loss: 0.000313147 (0.000209142) BC Loss: 0.000021294 (0.000026313)
Iteration: [ 18001/ 50000] Loss: 0.000217428 (0.000283468) Physics Loss: 0.000050234 (0.000058343) Data Loss: 0.000134317 (0.000199132) BC Loss: 0.000032878 (0.000025993)
Iteration: [ 19001/ 50000] Loss: 0.000218901 (0.000288576) Physics Loss: 0.000058156 (0.000058022) Data Loss: 0.000140315 (0.000205322) BC Loss: 0.000020430 (0.000025232)
Iteration: [ 20001/ 50000] Loss: 0.000294953 (0.000269254) Physics Loss: 0.000046109 (0.000060157) Data Loss: 0.000229066 (0.000184630) BC Loss: 0.000019777 (0.000024467)
Iteration: [ 21001/ 50000] Loss: 0.000300445 (0.000257151) Physics Loss: 0.000036583 (0.000056973) Data Loss: 0.000235420 (0.000175423) BC Loss: 0.000028442 (0.000024755)
Iteration: [ 22001/ 50000] Loss: 0.000169547 (0.000251041) Physics Loss: 0.000034612 (0.000055424) Data Loss: 0.000109064 (0.000171932) BC Loss: 0.000025871 (0.000023685)
Iteration: [ 23001/ 50000] Loss: 0.000234877 (0.000253784) Physics Loss: 0.000041697 (0.000050783) Data Loss: 0.000173086 (0.000180385) BC Loss: 0.000020094 (0.000022616)
Iteration: [ 24001/ 50000] Loss: 0.000277989 (0.000240709) Physics Loss: 0.000029074 (0.000041165) Data Loss: 0.000224187 (0.000174445) BC Loss: 0.000024728 (0.000025099)
Iteration: [ 25001/ 50000] Loss: 0.000204744 (0.000232048) Physics Loss: 0.000031281 (0.000037916) Data Loss: 0.000150238 (0.000172505) BC Loss: 0.000023225 (0.000021627)
Iteration: [ 26001/ 50000] Loss: 0.000201997 (0.000233869) Physics Loss: 0.000023859 (0.000043058) Data Loss: 0.000160899 (0.000167941) BC Loss: 0.000017240 (0.000022870)
Iteration: [ 27001/ 50000] Loss: 0.000229443 (0.000226531) Physics Loss: 0.000049537 (0.000037394) Data Loss: 0.000158349 (0.000167743) BC Loss: 0.000021557 (0.000021393)
Iteration: [ 28001/ 50000] Loss: 0.000207354 (0.000221756) Physics Loss: 0.000035808 (0.000039702) Data Loss: 0.000149648 (0.000160837) BC Loss: 0.000021898 (0.000021218)
Iteration: [ 29001/ 50000] Loss: 0.000213040 (0.000225860) Physics Loss: 0.000035587 (0.000043376) Data Loss: 0.000144921 (0.000160407) BC Loss: 0.000032532 (0.000022076)
Iteration: [ 30001/ 50000] Loss: 0.000234373 (0.000216143) Physics Loss: 0.000054144 (0.000033498) Data Loss: 0.000159639 (0.000162714) BC Loss: 0.000020590 (0.000019931)
Iteration: [ 31001/ 50000] Loss: 0.000255975 (0.000222872) Physics Loss: 0.000020559 (0.000042374) Data Loss: 0.000216269 (0.000160147) BC Loss: 0.000019146 (0.000020351)
Iteration: [ 32001/ 50000] Loss: 0.000200822 (0.000204340) Physics Loss: 0.000025161 (0.000029376) Data Loss: 0.000155289 (0.000155062) BC Loss: 0.000020371 (0.000019902)
Iteration: [ 33001/ 50000] Loss: 0.000178170 (0.000194639) Physics Loss: 0.000022239 (0.000024597) Data Loss: 0.000134938 (0.000150642) BC Loss: 0.000020993 (0.000019400)
Iteration: [ 34001/ 50000] Loss: 0.000193195 (0.000194885) Physics Loss: 0.000028776 (0.000028442) Data Loss: 0.000145554 (0.000147324) BC Loss: 0.000018864 (0.000019119)
Iteration: [ 35001/ 50000] Loss: 0.000142987 (0.000199877) Physics Loss: 0.000025405 (0.000031575) Data Loss: 0.000097128 (0.000149056) BC Loss: 0.000020454 (0.000019246)
Iteration: [ 36001/ 50000] Loss: 0.000160476 (0.000194203) Physics Loss: 0.000026703 (0.000024736) Data Loss: 0.000117877 (0.000149971) BC Loss: 0.000015896 (0.000019496)
Iteration: [ 37001/ 50000] Loss: 0.000283208 (0.000193232) Physics Loss: 0.000042159 (0.000027869) Data Loss: 0.000221629 (0.000146846) BC Loss: 0.000019421 (0.000018517)
Iteration: [ 38001/ 50000] Loss: 0.000245379 (0.000205418) Physics Loss: 0.000033956 (0.000035187) Data Loss: 0.000193104 (0.000152150) BC Loss: 0.000018320 (0.000018081)
Iteration: [ 39001/ 50000] Loss: 0.000177165 (0.000191839) Physics Loss: 0.000036057 (0.000026403) Data Loss: 0.000124458 (0.000146355) BC Loss: 0.000016650 (0.000019081)
Iteration: [ 40001/ 50000] Loss: 0.000197415 (0.000216035) Physics Loss: 0.000044426 (0.000051125) Data Loss: 0.000132311 (0.000145304) BC Loss: 0.000020679 (0.000019607)
Iteration: [ 41001/ 50000] Loss: 0.000157530 (0.000186492) Physics Loss: 0.000018711 (0.000020619) Data Loss: 0.000120454 (0.000148068) BC Loss: 0.000018365 (0.000017805)
Iteration: [ 42001/ 50000] Loss: 0.000173063 (0.000189557) Physics Loss: 0.000025646 (0.000024179) Data Loss: 0.000130955 (0.000146585) BC Loss: 0.000016462 (0.000018794)
Iteration: [ 43001/ 50000] Loss: 0.000214167 (0.000202162) Physics Loss: 0.000045533 (0.000032912) Data Loss: 0.000152926 (0.000147023) BC Loss: 0.000015708 (0.000022226)
Iteration: [ 44001/ 50000] Loss: 0.000171226 (0.000186233) Physics Loss: 0.000015765 (0.000022726) Data Loss: 0.000140696 (0.000143055) BC Loss: 0.000014764 (0.000020452)
Iteration: [ 45001/ 50000] Loss: 0.000259250 (0.000200943) Physics Loss: 0.000025167 (0.000027029) Data Loss: 0.000207201 (0.000152599) BC Loss: 0.000026882 (0.000021315)
Iteration: [ 46001/ 50000] Loss: 0.000182907 (0.000181578) Physics Loss: 0.000011949 (0.000025043) Data Loss: 0.000152303 (0.000138790) BC Loss: 0.000018655 (0.000017745)
Iteration: [ 47001/ 50000] Loss: 0.000167170 (0.000184437) Physics Loss: 0.000012181 (0.000026865) Data Loss: 0.000134047 (0.000139393) BC Loss: 0.000020942 (0.000018180)
Iteration: [ 48001/ 50000] Loss: 0.000186864 (0.000193711) Physics Loss: 0.000027376 (0.000035057) Data Loss: 0.000141050 (0.000138865) BC Loss: 0.000018438 (0.000019789)
Iteration: [ 49001/ 50000] Loss: 0.000179684 (0.000198348) Physics Loss: 0.000037170 (0.000036041) Data Loss: 0.000121571 (0.000142539) BC Loss: 0.000020942 (0.000019768)

Visualizing the Results
# Evaluation axes: a fine spatial grid and a coarser time grid over [0, 2].
xs = 0.0f0:0.02f0:2.0f0
ys = 0.0f0:0.02f0:2.0f0
ts = 0.0f0:0.05f0:2.0f0

# One (x, y, t) point per column, with x varying fastest.
grid = stack(map(collect, vec(collect(Iterators.product(xs, ys, ts)))))

# Reference solution on the evaluation grid, arranged as (x, y, t).
u_real = reshape(analytical_solution(grid), length(xs), length(ys), length(ts))

# Apply the same min-max rescaling to the inputs that was used for the training data.
grid_normalized = (grid .- minimum(grid)) ./ (maximum(grid) - minimum(grid))

# Predict, reshape to (x, y, t), and undo the target rescaling.
u_pred =
    reshape(trained_model(grid_normalized), length(xs), length(ys), length(ts)) .*
    (max_pde_val - min_pde_val) .+ min_pde_val
# Animate the absolute prediction error over time and save it as a GIF.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in eachindex(ts)]

    # One global colour range, shared by the colourbar and by every frame's heatmap,
    # so that colours are comparable across frames and match the colourbar.
    err_limits = extrema(stack(errs))
    Colorbar(fig[1, 2]; limits=err_limits)

    CairoMakie.record(fig, "pinn_nested_ad.gif", eachindex(ts); framerate=10) do i
        # Drop the previous frame's plots; otherwise they pile up in the axis.
        empty!(ax)
        ax.title = "Abs. Predictor Error | Time: $(ts[i])"
        err = errs[i]
        heatmap!(ax, xs, ys, err; colorrange=err_limits)
        # Drawn after the heatmap so the contour lines are not hidden beneath it.
        contour!(ax, xs, ys, err; levels=10, linewidth=2)
        return fig
    end

    fig
end

# Appendix
using InteractiveUtils
# Print the Julia/platform information for reproducibility.
InteractiveUtils.versioninfo()

# If MLDataDevices is loaded, also report each GPU backend that is both loaded and functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end

# Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate

This page was generated using Literate.jl.