Training a PINN on a 2D PDE
In this tutorial, we will go over using a PINN to solve 2D PDEs. We will use the system from the NeuralPDE tutorials; however, we will write our own custom loss function and use the nested AD capabilities of Lux.jl.
This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the package NeuralPDE.jl.
Package Imports
using Lux, Optimisers, Random, Printf, Statistics, MLUtils, OnlineStats, CairoMakie,
Reactant, Enzyme
const xdev = reactant_device(force=true)  # device that training data and parameters are moved to
const cdev = cpu_device()                 # host device, used to bring trained parameters back
(::MLDataDevices.CPUDevice) (generic function with 1 method)
Problem Definition
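We solve the PDE
$$\frac{\partial u}{\partial t} = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}, \qquad (x, y, t) \in [0, 2]^3,$$
whose analytical solution is $u(x, y, t) = e^{x + y} \cos(x + y + 4t)$. Since Lux supports efficient nested AD up to 2nd order, we rewrite the problem using only first-order derivatives. This lets us compute the gradients of the loss using 2nd-order AD. Introducing the auxiliary fields $v = \partial u / \partial x$ and $w = \partial u / \partial y$, the system becomes
$$\frac{\partial u}{\partial t} = \frac{\partial v}{\partial x} + \frac{\partial w}{\partial y}, \qquad v = \frac{\partial u}{\partial x}, \qquad w = \frac{\partial u}{\partial y}.$$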
Define the Neural Networks
All the networks take 3 input variables and output a scalar value. Here, we will define a wrapper over the 3 networks so that we can train them using `Training.TrainState`.
"""
    PINN(u, v, w)

Lux container layer bundling the three sub-networks of the PINN so that they can be set up
and trained together (e.g. with `Training.TrainState`). In the physics loss, `u` models the
solution, while `v` and `w` are fitted to ∂u/∂x and ∂u/∂y respectively.
"""
struct PINN{UNet, VNet, WNet} <: Lux.AbstractLuxContainerLayer{(:u, :v, :w)}
    u::UNet  # solution network
    v::VNet  # network tied to ∂u/∂x by the physics loss
    w::WNet  # network tied to ∂u/∂y by the physics loss
end
"""
    create_mlp(act, hidden_dims)

Build a 4-layer MLP mapping 3 inputs to 1 output. It has three `Dense` layers of width
`hidden_dims` with activation `act`, followed by a linear `Dense(hidden_dims => 1)` readout.
"""
function create_mlp(act, hidden_dims)
    first_layer = Dense(3 => hidden_dims, act)
    hidden_layers = ntuple(_ -> Dense(hidden_dims => hidden_dims, act), 2)
    readout = Dense(hidden_dims => 1)
    return Chain(first_layer, hidden_layers..., readout)
end
"""
    PINN(; hidden_dims::Int=32)

Construct a `PINN` from three independent `tanh` MLPs (see `create_mlp`), each with
`hidden_dims` hidden units.
"""
PINN(; hidden_dims::Int=32) = PINN(ntuple(_ -> create_mlp(tanh, hidden_dims), 3)...)
Main.var"##230".PINN
Define the Loss Functions
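We will define a custom loss function that computes the physics loss using 2nd-order AD. The network gradients with respect to the inputs are taken inside the loss, and training then differentiates the loss again with respect to the parameters. For $N$ collocation points, we use the following loss function:
$$\mathcal{L}_{\text{physics}} = \frac{1}{N}\sum_{i=1}^{N}\left(\frac{\partial u}{\partial t} - \frac{\partial v}{\partial x} - \frac{\partial w}{\partial y}\right)_i^2 + \frac{1}{N}\sum_{i=1}^{N}\left(v - \frac{\partial u}{\partial x}\right)_i^2 + \frac{1}{N}\sum_{i=1}^{N}\left(w - \frac{\partial u}{\partial y}\right)_i^2$$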
"""
    physics_informed_loss_function(u, v, w, xyt)

Residual loss of the first-order PDE system at the collocation points `xyt`, whose rows
1, 2, 3 hold x, y, t:

    ∂u/∂t = ∂v/∂x + ∂w/∂y,   v = ∂u/∂x,   w = ∂u/∂y

Input gradients come from Enzyme reverse mode applied to `sum ∘ net`. Column `j` of each
gradient is therefore the derivative of the output at point `j`, assuming the networks
treat columns independently (true for the Dense MLPs used here; TODO confirm for other
layers). Returns the sum of the three mean squared residuals.
"""
@views function physics_informed_loss_function(
        u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray
    )
    # Full input gradients (same shape as `xyt`); rows are later picked per coordinate.
    du = Enzyme.gradient(Enzyme.Reverse, sum ∘ u, xyt)[1]
    dv = Enzyme.gradient(Enzyme.Reverse, sum ∘ v, xyt)[1]
    v_pred = v(xyt)
    dw = Enzyme.gradient(Enzyme.Reverse, sum ∘ w, xyt)[1]
    w_pred = w(xyt)

    pde_residual = du[3:3, :] .- dv[1:1, :] .- dw[2:2, :]   # ∂u/∂t - ∂v/∂x - ∂w/∂y
    v_residual = v_pred .- du[1:1, :]                        # v - ∂u/∂x
    w_residual = w_pred .- du[2:2, :]                        # w - ∂u/∂y

    return mean(abs2, pde_residual) + mean(abs2, v_residual) + mean(abs2, w_residual)
end
physics_informed_loss_function (generic function with 1 method)
Additionally, we need mean-squared-error losses with respect to the boundary conditions and the interior data points; the same helper is used for both.
"""
    mse_loss_function(u, target, xyt)

Mean squared error (Lux `MSELoss`) between the predictions `u(xyt)` and `target`.
"""
function mse_loss_function(u::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray)
    criterion = MSELoss()
    prediction = u(xyt)
    return criterion(prediction, target)
end
"""
    loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))

Total training objective in the `(loss, new_states, stats)` form used by
`Training.single_train_step!`. The loss is the sum of three terms:
- the physics residual at `xyt`;
- the data MSE of `u` at `xyt`;
- the boundary MSE of `u` at `xyt_bc`.

`stats` holds the individual terms as `physics_loss`, `data_loss` and `bc_loss`.
"""
function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
    # Bind each sub-network to its parameters/states so it can be called like a function.
    u = StatefulLuxLayer{true}(model.u, ps.u, st.u)
    v = StatefulLuxLayer{true}(model.v, ps.v, st.v)
    w = StatefulLuxLayer{true}(model.w, ps.w, st.w)

    stats = (;
        physics_loss=physics_informed_loss_function(u, v, w, xyt),
        data_loss=mse_loss_function(u, target_data, xyt),
        bc_loss=mse_loss_function(u, target_bc, xyt_bc),
    )
    total = stats.physics_loss + stats.data_loss + stats.bc_loss
    new_states = (; u=u.st, v=v.st, w=w.st)

    return total, new_states, stats
end
loss_function (generic function with 1 method)
Generate the Data
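We will generate the training data from the analytical solution. Interior points lie on a regular grid over the cubic spatio-temporal domain $(x, y, t) \in [0, 2]^3$. Boundary points lie on the faces $t = 0$, $x = 0$, $x = 2$, $y = 0$ and $y = 2$. Both the inputs and the targets are then rescaled to $[0, 1]$.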
"""
    analytical_solution(x, y, t)
    analytical_solution(xyt)

Closed-form solution u(x, y, t) = exp(x + y) * cos(x + y + 4t), evaluated elementwise.
It is used to build the training targets and to measure the prediction error. The
one-argument form reads x, y and t from rows 1, 2 and 3 of the matrix `xyt`.
"""
analytical_solution(x, y, t) = exp.(x .+ y) .* cos.(x .+ y .+ 4 .* t)
analytical_solution(xyt) = @views analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :])
begin
    # Spatio-temporal domain: x, y, t ∈ [domain_min, domain_max].
    domain_min, domain_max = 0.0f0, 2.0f0

    # Interior points: a regular grid_len^3 grid stored as a 3 × grid_len^3 matrix with rows
    # (x, y, t), plus the analytical solution at each point as a 1 × N target row.
    grid_len = 16
    grid = range(domain_min, domain_max; length=grid_len)
    xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))])
    target_data = reshape(analytical_solution(xyt), 1, :)

    # Boundary points: bc_len points on each of the faces t = 0, x = 0, x = 2, y = 0 and
    # y = 2. The free coordinates are sampled independently (with a fixed seed) so the
    # points spread over each face. Reusing one evenly spaced range for x, y and t would
    # make x == y == t pointwise and only cover the diagonal of every face.
    bc_len = 512
    bc_rng = Random.Xoshiro(0)
    x = domain_min .+ (domain_max - domain_min) .* rand(bc_rng, Float32, bc_len)
    y = domain_min .+ (domain_max - domain_min) .* rand(bc_rng, Float32, bc_len)
    t = domain_min .+ (domain_max - domain_min) .* rand(bc_rng, Float32, bc_len)
    xyt_bc = hcat(
        stack((x, y, fill(domain_min, bc_len)); dims=1),  # t = 0
        stack((fill(domain_min, bc_len), y, t); dims=1),  # x = 0
        stack((fill(domain_max, bc_len), y, t); dims=1),  # x = 2
        stack((x, fill(domain_min, bc_len), t); dims=1),  # y = 0
        stack((x, fill(domain_max, bc_len), t); dims=1)   # y = 2
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    # Shared target range over interior and boundary data; targets are mapped to [0, 1]
    # with it, and min_pde_val / max_pde_val remain available to undo that scaling.
    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc)

    # Map inputs to [0, 1] using the known domain bounds, so interior and boundary points
    # share one scaling regardless of which boundary samples were drawn.
    xyt = (xyt .- domain_min) ./ (domain_max - domain_min)
    xyt_bc = (xyt_bc .- domain_min) ./ (domain_max - domain_min)
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end
Training
"""
    train_model(xyt, target_data, xyt_bc, target_bc; seed::Int=0, maxiters::Int=50000,
                hidden_dims::Int=32)

Train a `PINN` with Adam for exactly `maxiters` iterations. Return the trained model as a
`StatefulLuxLayer` whose parameters and states are moved back to the CPU.

Each iteration uses one mini-batch of 32 interior points (`xyt` / `target_data`) and one
mini-batch of 32 boundary points (`xyt_bc` / `target_bc`). Both data loaders are cycled
indefinitely.

The learning rate is:
- 0.05 for iterations below 5000;
- 0.005 below 10000;
- 0.0005 afterwards.

Current losses and their averages over the last 32 iterations are printed every 1000
iterations and at the final iteration.

Throws an `ArgumentError` as soon as the total loss becomes `NaN`.
"""
function train_model(
        xyt, target_data, xyt_bc, target_bc; seed::Int=0,
        maxiters::Int=50000, hidden_dims::Int=32
    )
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> xdev

    # `partial=false` drops incomplete trailing batches, so every batch has 32 columns.
    bc_dataloader = DataLoader(
        (xyt_bc, target_bc); batchsize=32, shuffle=true, partial=false
    ) |> xdev
    pde_dataloader = DataLoader(
        (xyt, target_data); batchsize=32, shuffle=true, partial=false
    ) |> xdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.05f0))

    # Piecewise-constant learning-rate schedule.
    lr = i -> i < 5000 ? 0.05f0 : (i < 10000 ? 0.005f0 : 0.0005f0)

    # Circular buffers holding the 32 most recent values of each loss term.
    total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
        _ -> OnlineStats.CircBuff(Float32, 32; rev=true), 4
    )

    # Zipping with `1:maxiters` bounds the infinite cycled loaders and runs exactly
    # `maxiters` steps. The previous `iter += 1; iter ≥ maxiters && break` pattern stopped
    # one step early and never reached the `iter == maxiters` log line.
    for (iter, (xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in zip(
            1:maxiters, Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader)
        )
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoEnzyme(), loss_function,
            (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state
        )

        # Convert the returned scalars to host Float32 once and reuse them below.
        loss_val = Float32(loss)
        physics_loss_val = Float32(stats.physics_loss)
        data_loss_val = Float32(stats.data_loss)
        bc_loss_val = Float32(stats.bc_loss)

        isnan(loss_val) && throw(ArgumentError("NaN Loss Detected"))

        fit!(total_loss_tracker, loss_val)
        fit!(physics_loss_tracker, physics_loss_val)
        fit!(data_loss_tracker, data_loss_val)
        fit!(bc_loss_tracker, bc_loss_val)

        if iter % 1000 == 1 || iter == maxiters
            # Running averages are only needed when logging.
            mean_loss = mean(OnlineStats.value(total_loss_tracker))
            mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
            mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
            mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))
            @printf "Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
                     (%.9f) \t Data Loss: %.9f (%.9f) \t BC \
                     Loss: %.9f (%.9f)\n" iter maxiters loss_val mean_loss physics_loss_val mean_physics_loss data_loss_val mean_data_loss bc_loss_val mean_bc_loss
        end
    end

    return StatefulLuxLayer{true}(
        pinn, cdev(train_state.parameters), cdev(train_state.states)
    )
end
trained_model = train_model(xyt, target_data, xyt_bc, target_bc)

# Pull out the `u` sub-network with its trained parameters and states, in test mode.
trained_u = let m = trained_model
    Lux.testmode(StatefulLuxLayer{true}(m.model.u, m.ps.u, m.st.u))
end
2025-01-24 06:16:16.628573: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:50] XLA (re)initializing LLVM with options fingerprint: 14915741955778251973
E0124 06:16:17.025315 2214143 buffer_comparator.cc:156] Difference at 16: 0, expected 11.6059
E0124 06:16:17.026421 2214143 buffer_comparator.cc:156] Difference at 17: 0, expected 14.502
E0124 06:16:17.026429 2214143 buffer_comparator.cc:156] Difference at 18: 0, expected 11.2449
E0124 06:16:17.026436 2214143 buffer_comparator.cc:156] Difference at 19: 0, expected 10.0998
E0124 06:16:17.026443 2214143 buffer_comparator.cc:156] Difference at 20: 0, expected 14.0222
E0124 06:16:17.026449 2214143 buffer_comparator.cc:156] Difference at 21: 0, expected 10.1321
E0124 06:16:17.026455 2214143 buffer_comparator.cc:156] Difference at 22: 0, expected 10.2986
E0124 06:16:17.026462 2214143 buffer_comparator.cc:156] Difference at 23: 0, expected 14.1109
E0124 06:16:17.026468 2214143 buffer_comparator.cc:156] Difference at 24: 0, expected 13.3463
E0124 06:16:17.026475 2214143 buffer_comparator.cc:156] Difference at 25: 0, expected 12.8369
2025-01-24 06:16:17.026500: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1080] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0124 06:16:17.029313 2214143 buffer_comparator.cc:156] Difference at 16: 0, expected 11.6059
E0124 06:16:17.029342 2214143 buffer_comparator.cc:156] Difference at 17: 0, expected 14.502
E0124 06:16:17.029349 2214143 buffer_comparator.cc:156] Difference at 18: 0, expected 11.2449
E0124 06:16:17.029355 2214143 buffer_comparator.cc:156] Difference at 19: 0, expected 10.0998
E0124 06:16:17.029362 2214143 buffer_comparator.cc:156] Difference at 20: 0, expected 14.0222
E0124 06:16:17.029368 2214143 buffer_comparator.cc:156] Difference at 21: 0, expected 10.1321
E0124 06:16:17.029374 2214143 buffer_comparator.cc:156] Difference at 22: 0, expected 10.2986
E0124 06:16:17.029381 2214143 buffer_comparator.cc:156] Difference at 23: 0, expected 14.1109
E0124 06:16:17.029387 2214143 buffer_comparator.cc:156] Difference at 24: 0, expected 13.3463
E0124 06:16:17.029393 2214143 buffer_comparator.cc:156] Difference at 25: 0, expected 12.8369
2025-01-24 06:16:17.029404: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1080] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0124 06:16:17.032069 2214143 buffer_comparator.cc:156] Difference at 1024: 0, expected 11.5293
E0124 06:16:17.032083 2214143 buffer_comparator.cc:156] Difference at 1025: 0, expected 10.1983
E0124 06:16:17.032086 2214143 buffer_comparator.cc:156] Difference at 1026: 0, expected 13.3385
E0124 06:16:17.032088 2214143 buffer_comparator.cc:156] Difference at 1027: 0, expected 12.4705
E0124 06:16:17.032091 2214143 buffer_comparator.cc:156] Difference at 1028: 0, expected 8.94387
E0124 06:16:17.032094 2214143 buffer_comparator.cc:156] Difference at 1029: 0, expected 10.8997
E0124 06:16:17.032097 2214143 buffer_comparator.cc:156] Difference at 1030: 0, expected 10.6486
E0124 06:16:17.032100 2214143 buffer_comparator.cc:156] Difference at 1031: 0, expected 9.73507
E0124 06:16:17.032102 2214143 buffer_comparator.cc:156] Difference at 1032: 0, expected 12.2806
E0124 06:16:17.032105 2214143 buffer_comparator.cc:156] Difference at 1033: 0, expected 10.1883
2025-01-24 06:16:17.032110: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1080] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0124 06:16:17.034641 2214143 buffer_comparator.cc:156] Difference at 1040: 0, expected 9.99799
E0124 06:16:17.034657 2214143 buffer_comparator.cc:156] Difference at 1041: 0, expected 12.209
E0124 06:16:17.034660 2214143 buffer_comparator.cc:156] Difference at 1042: 0, expected 9.4851
E0124 06:16:17.034663 2214143 buffer_comparator.cc:156] Difference at 1043: 0, expected 8.26397
E0124 06:16:17.034665 2214143 buffer_comparator.cc:156] Difference at 1044: 0, expected 11.9253
E0124 06:16:17.034668 2214143 buffer_comparator.cc:156] Difference at 1045: 0, expected 8.99047
E0124 06:16:17.034672 2214143 buffer_comparator.cc:156] Difference at 1046: 0, expected 8.81842
E0124 06:16:17.034675 2214143 buffer_comparator.cc:156] Difference at 1047: 0, expected 12.2714
E0124 06:16:17.034678 2214143 buffer_comparator.cc:156] Difference at 1048: 0, expected 11.1417
E0124 06:16:17.034681 2214143 buffer_comparator.cc:156] Difference at 1049: 0, expected 10.6572
2025-01-24 06:16:17.034685: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1080] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0124 06:16:17.037223 2214143 buffer_comparator.cc:156] Difference at 1056: 0, expected 10.6543
E0124 06:16:17.037237 2214143 buffer_comparator.cc:156] Difference at 1057: 0, expected 11.0945
E0124 06:16:17.037240 2214143 buffer_comparator.cc:156] Difference at 1058: 0, expected 11.1424
E0124 06:16:17.037242 2214143 buffer_comparator.cc:156] Difference at 1059: 0, expected 12.7556
E0124 06:16:17.037245 2214143 buffer_comparator.cc:156] Difference at 1060: 0, expected 12.6932
E0124 06:16:17.037248 2214143 buffer_comparator.cc:156] Difference at 1061: 0, expected 10.0594
E0124 06:16:17.037251 2214143 buffer_comparator.cc:156] Difference at 1062: 0, expected 12.3478
E0124 06:16:17.037254 2214143 buffer_comparator.cc:156] Difference at 1063: 0, expected 10.8381
E0124 06:16:17.037256 2214143 buffer_comparator.cc:156] Difference at 1064: 0, expected 10.409
E0124 06:16:17.037259 2214143 buffer_comparator.cc:156] Difference at 1065: 0, expected 10.3688
2025-01-24 06:16:17.037264: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1080] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0124 06:16:17.039798 2214143 buffer_comparator.cc:156] Difference at 1056: 0, expected 10.6543
E0124 06:16:17.039811 2214143 buffer_comparator.cc:156] Difference at 1057: 0, expected 11.0945
E0124 06:16:17.039814 2214143 buffer_comparator.cc:156] Difference at 1058: 0, expected 11.1424
E0124 06:16:17.039817 2214143 buffer_comparator.cc:156] Difference at 1059: 0, expected 12.7556
E0124 06:16:17.039820 2214143 buffer_comparator.cc:156] Difference at 1060: 0, expected 12.6932
E0124 06:16:17.039823 2214143 buffer_comparator.cc:156] Difference at 1061: 0, expected 10.0594
E0124 06:16:17.039826 2214143 buffer_comparator.cc:156] Difference at 1062: 0, expected 12.3478
E0124 06:16:17.039828 2214143 buffer_comparator.cc:156] Difference at 1063: 0, expected 10.8381
E0124 06:16:17.039831 2214143 buffer_comparator.cc:156] Difference at 1064: 0, expected 10.409
E0124 06:16:17.039834 2214143 buffer_comparator.cc:156] Difference at 1065: 0, expected 10.3688
2025-01-24 06:16:17.039838: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1080] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0124 06:16:17.042368 2214143 buffer_comparator.cc:156] Difference at 1056: 0, expected 10.6543
E0124 06:16:17.042383 2214143 buffer_comparator.cc:156] Difference at 1057: 0, expected 11.0945
E0124 06:16:17.042386 2214143 buffer_comparator.cc:156] Difference at 1058: 0, expected 11.1424
E0124 06:16:17.042389 2214143 buffer_comparator.cc:156] Difference at 1059: 0, expected 12.7556
E0124 06:16:17.042392 2214143 buffer_comparator.cc:156] Difference at 1060: 0, expected 12.6932
E0124 06:16:17.042395 2214143 buffer_comparator.cc:156] Difference at 1061: 0, expected 10.0594
E0124 06:16:17.042397 2214143 buffer_comparator.cc:156] Difference at 1062: 0, expected 12.3478
E0124 06:16:17.042400 2214143 buffer_comparator.cc:156] Difference at 1063: 0, expected 10.8381
E0124 06:16:17.042403 2214143 buffer_comparator.cc:156] Difference at 1064: 0, expected 10.409
E0124 06:16:17.042406 2214143 buffer_comparator.cc:156] Difference at 1065: 0, expected 10.3688
2025-01-24 06:16:17.042410: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1080] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Iteration: [ 1/ 50000] Loss: 3.158992767 (3.158992767) Physics Loss: 1.982357264 (1.982357264) Data Loss: 0.578243732 (0.578243732) BC Loss: 0.598391712 (0.598391712)
Iteration: [ 1001/ 50000] Loss: 0.034934171 (0.028678153) Physics Loss: 0.000550666 (0.000427971) Data Loss: 0.022407684 (0.011849238) BC Loss: 0.011975820 (0.016400939)
Iteration: [ 2001/ 50000] Loss: 0.029342793 (0.031756539) Physics Loss: 0.001803906 (0.001986223) Data Loss: 0.013016323 (0.012399685) BC Loss: 0.014522564 (0.017370628)
Iteration: [ 3001/ 50000] Loss: 0.021532167 (0.027456025) Physics Loss: 0.002033974 (0.004514552) Data Loss: 0.004263131 (0.008210877) BC Loss: 0.015235063 (0.014730594)
Iteration: [ 4001/ 50000] Loss: 0.049741872 (0.035063371) Physics Loss: 0.008673832 (0.002975470) Data Loss: 0.016625574 (0.012557827) BC Loss: 0.024442466 (0.019530077)
Iteration: [ 5001/ 50000] Loss: 0.021555956 (0.037785888) Physics Loss: 0.002584497 (0.009353531) Data Loss: 0.008988982 (0.011811271) BC Loss: 0.009982477 (0.016621085)
Iteration: [ 6001/ 50000] Loss: 0.022884602 (0.019320106) Physics Loss: 0.000503887 (0.000844118) Data Loss: 0.005671175 (0.006657102) BC Loss: 0.016709540 (0.011818888)
Iteration: [ 7001/ 50000] Loss: 0.017019382 (0.019912010) Physics Loss: 0.000786213 (0.000926178) Data Loss: 0.006546173 (0.007675517) BC Loss: 0.009686996 (0.011310317)
Iteration: [ 8001/ 50000] Loss: 0.022041641 (0.018215429) Physics Loss: 0.002748450 (0.002823828) Data Loss: 0.004308246 (0.004818310) BC Loss: 0.014984946 (0.010573289)
Iteration: [ 9001/ 50000] Loss: 0.020825051 (0.016353965) Physics Loss: 0.001645100 (0.001697728) Data Loss: 0.003023159 (0.004666437) BC Loss: 0.016156793 (0.009989802)
Iteration: [ 10001/ 50000] Loss: 0.012419771 (0.014550619) Physics Loss: 0.001211297 (0.002143627) Data Loss: 0.006038338 (0.003847382) BC Loss: 0.005170135 (0.008559611)
Iteration: [ 11001/ 50000] Loss: 0.016400268 (0.013275324) Physics Loss: 0.001640767 (0.001365761) Data Loss: 0.003816272 (0.003336374) BC Loss: 0.010943229 (0.008573188)
Iteration: [ 12001/ 50000] Loss: 0.012625601 (0.012438955) Physics Loss: 0.001496360 (0.001293942) Data Loss: 0.005017307 (0.003183292) BC Loss: 0.006111934 (0.007961722)
Iteration: [ 13001/ 50000] Loss: 0.006276353 (0.011658882) Physics Loss: 0.000723396 (0.001394610) Data Loss: 0.001822183 (0.002704730) BC Loss: 0.003730775 (0.007559542)
Iteration: [ 14001/ 50000] Loss: 0.007330933 (0.010494760) Physics Loss: 0.001231446 (0.001130134) Data Loss: 0.003328452 (0.002512253) BC Loss: 0.002771034 (0.006852372)
Iteration: [ 15001/ 50000] Loss: 0.012322948 (0.010581179) Physics Loss: 0.001112666 (0.001342141) Data Loss: 0.000318603 (0.002849674) BC Loss: 0.010891679 (0.006389363)
Iteration: [ 16001/ 50000] Loss: 0.005651589 (0.011476245) Physics Loss: 0.001261572 (0.001271797) Data Loss: 0.001857541 (0.002460307) BC Loss: 0.002532476 (0.007744140)
Iteration: [ 17001/ 50000] Loss: 0.004589280 (0.010282698) Physics Loss: 0.000900613 (0.001197141) Data Loss: 0.000789952 (0.002489067) BC Loss: 0.002898715 (0.006596491)
Iteration: [ 18001/ 50000] Loss: 0.013593317 (0.008794412) Physics Loss: 0.001301785 (0.001351097) Data Loss: 0.003135182 (0.001751710) BC Loss: 0.009156350 (0.005691605)
Iteration: [ 19001/ 50000] Loss: 0.009044455 (0.009721650) Physics Loss: 0.001797066 (0.001625932) Data Loss: 0.003098861 (0.001854344) BC Loss: 0.004148528 (0.006241375)
Iteration: [ 20001/ 50000] Loss: 0.004531536 (0.007789471) Physics Loss: 0.002314395 (0.001649175) Data Loss: 0.001491809 (0.001308756) BC Loss: 0.000725332 (0.004831538)
Iteration: [ 21001/ 50000] Loss: 0.004263383 (0.006220151) Physics Loss: 0.001818714 (0.001836546) Data Loss: 0.001014963 (0.001413772) BC Loss: 0.001429706 (0.002969833)
Iteration: [ 22001/ 50000] Loss: 0.006228818 (0.005886041) Physics Loss: 0.003274539 (0.002052142) Data Loss: 0.002419390 (0.000916688) BC Loss: 0.000534889 (0.002917211)
Iteration: [ 23001/ 50000] Loss: 0.004727511 (0.004308250) Physics Loss: 0.001700793 (0.001910425) Data Loss: 0.001145163 (0.000927515) BC Loss: 0.001881554 (0.001470311)
Iteration: [ 24001/ 50000] Loss: 0.005070512 (0.003963022) Physics Loss: 0.001458622 (0.001738792) Data Loss: 0.000377079 (0.000693607) BC Loss: 0.003234812 (0.001530622)
Iteration: [ 25001/ 50000] Loss: 0.002241082 (0.003130429) Physics Loss: 0.000983181 (0.001482069) Data Loss: 0.000401148 (0.000641221) BC Loss: 0.000856754 (0.001007139)
Iteration: [ 26001/ 50000] Loss: 0.002955514 (0.002503448) Physics Loss: 0.001489272 (0.001270391) Data Loss: 0.000252445 (0.000556399) BC Loss: 0.001213797 (0.000676658)
Iteration: [ 27001/ 50000] Loss: 0.003096203 (0.002668173) Physics Loss: 0.001007635 (0.001394353) Data Loss: 0.000529087 (0.000551110) BC Loss: 0.001559482 (0.000722710)
Iteration: [ 28001/ 50000] Loss: 0.001113268 (0.002219696) Physics Loss: 0.000669667 (0.001248004) Data Loss: 0.000210634 (0.000505303) BC Loss: 0.000232967 (0.000466388)
Iteration: [ 29001/ 50000] Loss: 0.001584558 (0.001840178) Physics Loss: 0.000857728 (0.000936380) Data Loss: 0.000405950 (0.000491625) BC Loss: 0.000320880 (0.000412173)
Iteration: [ 30001/ 50000] Loss: 0.001524673 (0.001886792) Physics Loss: 0.000554989 (0.001022693) Data Loss: 0.000461125 (0.000466423) BC Loss: 0.000508558 (0.000397676)
Iteration: [ 31001/ 50000] Loss: 0.001906899 (0.002045441) Physics Loss: 0.000959212 (0.001217105) Data Loss: 0.000290614 (0.000451727) BC Loss: 0.000657073 (0.000376609)
Iteration: [ 32001/ 50000] Loss: 0.003575745 (0.002317760) Physics Loss: 0.002933637 (0.001479890) Data Loss: 0.000278991 (0.000500438) BC Loss: 0.000363118 (0.000337431)
Iteration: [ 33001/ 50000] Loss: 0.001054667 (0.001392871) Physics Loss: 0.000542178 (0.000709339) Data Loss: 0.000279479 (0.000395160) BC Loss: 0.000233010 (0.000288372)
Iteration: [ 34001/ 50000] Loss: 0.001258306 (0.001430126) Physics Loss: 0.000604900 (0.000740890) Data Loss: 0.000279625 (0.000439273) BC Loss: 0.000373781 (0.000249964)
Iteration: [ 35001/ 50000] Loss: 0.001458327 (0.001451151) Physics Loss: 0.001126388 (0.000859504) Data Loss: 0.000209119 (0.000364929) BC Loss: 0.000122820 (0.000226719)
Iteration: [ 36001/ 50000] Loss: 0.001965113 (0.001363893) Physics Loss: 0.000819583 (0.000818373) Data Loss: 0.000884497 (0.000343106) BC Loss: 0.000261034 (0.000202413)
Iteration: [ 37001/ 50000] Loss: 0.001001803 (0.001306025) Physics Loss: 0.000458097 (0.000750440) Data Loss: 0.000366133 (0.000375299) BC Loss: 0.000177574 (0.000180287)
Iteration: [ 38001/ 50000] Loss: 0.001594038 (0.001202732) Physics Loss: 0.000834267 (0.000656006) Data Loss: 0.000438084 (0.000355202) BC Loss: 0.000321687 (0.000191524)
Iteration: [ 39001/ 50000] Loss: 0.001096051 (0.001220055) Physics Loss: 0.000612216 (0.000665410) Data Loss: 0.000219414 (0.000392687) BC Loss: 0.000264422 (0.000161957)
Iteration: [ 40001/ 50000] Loss: 0.001038662 (0.001444635) Physics Loss: 0.000485165 (0.000887432) Data Loss: 0.000353080 (0.000370217) BC Loss: 0.000200418 (0.000186986)
Iteration: [ 41001/ 50000] Loss: 0.000912517 (0.001242174) Physics Loss: 0.000559338 (0.000678301) Data Loss: 0.000281135 (0.000411240) BC Loss: 0.000072045 (0.000152633)
Iteration: [ 42001/ 50000] Loss: 0.001224924 (0.001465745) Physics Loss: 0.001037403 (0.000994991) Data Loss: 0.000128774 (0.000311622) BC Loss: 0.000058747 (0.000159133)
Iteration: [ 43001/ 50000] Loss: 0.000964010 (0.001390206) Physics Loss: 0.000636658 (0.000894072) Data Loss: 0.000201709 (0.000336276) BC Loss: 0.000125642 (0.000159858)
Iteration: [ 44001/ 50000] Loss: 0.000622213 (0.001140883) Physics Loss: 0.000339070 (0.000657654) Data Loss: 0.000136122 (0.000339579) BC Loss: 0.000147021 (0.000143649)
Iteration: [ 45001/ 50000] Loss: 0.000991343 (0.001148664) Physics Loss: 0.000418792 (0.000707322) Data Loss: 0.000449356 (0.000315701) BC Loss: 0.000123195 (0.000125641)
Iteration: [ 46001/ 50000] Loss: 0.000918692 (0.000882140) Physics Loss: 0.000450692 (0.000441957) Data Loss: 0.000271820 (0.000319554) BC Loss: 0.000196180 (0.000120629)
Iteration: [ 47001/ 50000] Loss: 0.001575661 (0.001149394) Physics Loss: 0.001348543 (0.000694363) Data Loss: 0.000169103 (0.000319502) BC Loss: 0.000058015 (0.000135529)
Iteration: [ 48001/ 50000] Loss: 0.001518319 (0.001056366) Physics Loss: 0.001149551 (0.000625095) Data Loss: 0.000176694 (0.000317998) BC Loss: 0.000192074 (0.000113272)
Iteration: [ 49001/ 50000] Loss: 0.002202400 (0.001657575) Physics Loss: 0.001789718 (0.001203460) Data Loss: 0.000310050 (0.000312800) BC Loss: 0.000102632 (0.000141316)
Visualizing the Results
ts, xs, ys = 0.0f0:0.05f0:2.0f0, 0.0f0:0.02f0:2.0f0, 0.0f0:0.02f0:2.0f0
grid = stack([[elem...] for elem in vec(collect(Iterators.product(xs, ys, ts)))])
u_real = reshape(analytical_solution(grid), length(xs), length(ys), length(ts))

# Rescale the evaluation inputs to [0, 1] (min-max over the grid), predict, and map the
# predictions back from the normalized target range.
grid_normalized = (grid .- minimum(grid)) ./ (maximum(grid) .- minimum(grid))
u_pred = reshape(trained_u(grid_normalized), length(xs), length(ys), length(ts))
u_pred = u_pred .* (max_pde_val - min_pde_val) .+ min_pde_val

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    # Absolute error on the (x, y) grid, one matrix per time step.
    errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in eachindex(ts)]
    err_limits = extrema(stack(errs))

    # Plot once and update an Observable per frame. Adding a new heatmap/contour every
    # frame would pile up plots and hide each contour under the newer heatmap. A fixed
    # colorrange keeps colors comparable across frames and lets the colorbar describe
    # the heatmap.
    frame_err = Observable(errs[1])
    hm = heatmap!(ax, xs, ys, frame_err; colorrange=err_limits)
    contour!(ax, xs, ys, frame_err; levels=10, linewidth=2, color=:black)
    Colorbar(fig[1, 2], hm)

    CairoMakie.record(fig, "pinn_nested_ad.gif", eachindex(ts); framerate=10) do i
        ax.title = "Abs. Predictor Error | Time: $(ts[i])"
        frame_err[] = errs[i]
        return fig
    end
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

# Additionally report accelerator toolkit versions for backends that are loaded and usable.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.