Training a PINN on a 2D PDE
In this tutorial, we will use a PINN to solve a 2D PDE. We will use the system from the NeuralPDE tutorials; however, we will write our own custom loss function and use the nested AD capabilities of Lux.jl.
This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the NeuralPDE.jl package.
Package Imports
using Lux, Optimisers, Random, Printf, Statistics, MLUtils, OnlineStats, CairoMakie,
Reactant, Enzyme
# Device used for training: the Reactant (XLA) device. NOTE(review): `force = true`
# presumably makes this error instead of silently falling back to the CPU when Reactant
# is unavailable — confirm against the MLDataDevices documentation.
const xdev = reactant_device(; force = true)
# Host device; used to copy the trained parameters/states back to CPU memory.
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
Problem Definition
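Since Lux supports efficient nested AD up to 2nd order, we rewrite the problem using only first-order derivatives, so that the gradients of the loss can be computed with 2nd-order AD. The PDE $\partial_t u = \partial_x^2 u + \partial_y^2 u$ becomes the first-order system $\partial_t u = \partial_x v + \partial_y w$, $v = \partial_x u$, $w = \partial_y u$, posed on $(x, y, t) \in [0, 2]^3$, with analytical solution $u(x, y, t) = e^{x + y} \cos(x + y + 4t)$.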
Define the Neural Networks
All the networks take 3 input variables and output a scalar value. Here, we define a wrapper over the 3 networks so that we can train them using `Training.TrainState`.
"""
    PINN{NetU, NetV, NetW}

Container layer bundling the three sub-networks of the first-order system: `u` models the
solution, while `v` and `w` are trained to match ∂u/∂x and ∂u/∂y respectively. As an
`AbstractLuxContainerLayer` over `(:u, :v, :w)`, its parameters and states are named
tuples with the keys `u`, `v` and `w`.
"""
struct PINN{NetU, NetV, NetW} <: Lux.AbstractLuxContainerLayer{(:u, :v, :w)}
    u::NetU  # network for u(x, y, t)
    v::NetV  # network fitted to ∂u/∂x
    w::NetW  # network fitted to ∂u/∂y
end
"""
    create_mlp(act, hidden_dims)

Build a fully connected `Chain` mapping 3 inputs to 1 output: three `Dense` layers of
width `hidden_dims` with activation `act`, followed by a linear `Dense(hidden_dims => 1)`
output layer.
"""
function create_mlp(act, hidden_dims)
    input_layer = Dense(3 => hidden_dims, act)
    # Two identical width-preserving hidden layers.
    hidden_layers = (Dense(hidden_dims => hidden_dims, act) for _ in 1:2)
    output_layer = Dense(hidden_dims => 1)
    return Chain(input_layer, hidden_layers..., output_layer)
end
"""
    PINN(; hidden_dims::Int = 32)

Construct a `PINN` whose three sub-networks are independent `tanh` MLPs (see
`create_mlp`) with hidden width `hidden_dims`.
"""
function PINN(; hidden_dims::Int = 32)
    net_u, net_v, net_w = ntuple(_ -> create_mlp(tanh, hidden_dims), 3)
    return PINN(net_u, net_v, net_w)
end
Main.var"##230".PINN
Define the Loss Functions
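We define a custom physics loss whose terms contain input gradients of the networks, so differentiating it with respect to the parameters requires 2nd-order (nested) AD. It is the sum of the mean squared residuals of the three first-order equations: $\operatorname{mean}\big((\partial_t u - \partial_x v - \partial_y w)^2\big) + \operatorname{mean}\big((v - \partial_x u)^2\big) + \operatorname{mean}\big((w - \partial_y u)^2\big)$.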
# Physics-informed (PDE residual) loss for the first-order system encoded by the three
# networks: `u` models the solution, while `v` and `w` are trained to match ∂u/∂x and
# ∂u/∂y. The returned scalar is the sum of the mean squared residuals of
#     ∂u/∂t - ∂v/∂x - ∂w/∂y = 0,    v - ∂u/∂x = 0,    w - ∂u/∂y = 0,
# i.e. (eliminating v and w) u_t = u_xx + u_yy in the networks' input coordinates.
# Rows 1, 2 and 3 of `xyt` are treated as x, y and t (one point per column), as the
# slicing below shows. Only first-order input gradients appear here; it is
# differentiating this loss w.r.t. the parameters that requires nested (2nd-order) AD.
# `@views` makes the row slices below views instead of copies.
#
# NOTE(review): the training inputs are min-max scaled from [0, 2] to [0, 1] before they
# reach this function, which rescales derivatives (∂/∂x' = 2 ∂/∂x). For this tutorial's
# analytical solution that appears to turn the PDE into u_t' = (u_x'x' + u_y'y') / 2 in
# scaled coordinates, so the residual below would not vanish at the exact solution —
# confirm, and consider compensating for the input scaling.
@views function physics_informed_loss_function(
    u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray
)
    # Input gradient of sum(u(xyt)); presumably column j then holds ∇u at point j, since
    # these MLPs process each column independently — TODO confirm for other layers.
    ∂u_∂xyt = Enzyme.gradient(Enzyme.Reverse, sum ∘ u, xyt)[1]
    ∂u_∂x, ∂u_∂y, ∂u_∂t = ∂u_∂xyt[1:1, :], ∂u_∂xyt[2:2, :], ∂u_∂xyt[3:3, :]
    # Only ∂v/∂x (row 1) and ∂w/∂y (row 2) are needed for the divergence term.
    ∂v_∂x = Enzyme.gradient(Enzyme.Reverse, sum ∘ v, xyt)[1][1:1, :]
    v_xyt = v(xyt)
    ∂w_∂y = Enzyme.gradient(Enzyme.Reverse, sum ∘ w, xyt)[1][2:2, :]
    w_xyt = w(xyt)
    return (
        mean(abs2, ∂u_∂t .- ∂v_∂x .- ∂w_∂y) +  # evolution equation residual
        mean(abs2, v_xyt .- ∂u_∂x) +           # v ≈ ∂u/∂x
        mean(abs2, w_xyt .- ∂u_∂y)             # w ≈ ∂u/∂y
    )
end
physics_informed_loss_function (generic function with 1 method)
Additionally, we need a loss with respect to the boundary conditions; the same mean-squared-error helper is also used for the data loss on the collocation grid.
# Mean squared error between the network prediction `u(xyt)` and `target`
# (computed with Lux's `MSELoss`).
mse_loss_function(u::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray) =
    MSELoss()(u(xyt), target)
"""
    loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))

Training objective in the `(loss, states, stats)` form expected by
`Training.single_train_step!`.

The loss is the sum of three terms:
- the physics residual loss, evaluated at `xyt`;
- the data MSE of `u` against `target_data`, also at `xyt`;
- the boundary MSE of `u` against `target_bc`, evaluated at `xyt_bc`.

Returns the total loss, the updated states of the three sub-networks, and a named tuple
with the individual terms.
"""
function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
    # Stateful wrappers let the sub-networks be called like plain functions (as the
    # nested-AD physics loss needs) while tracking their states.
    u_net = StatefulLuxLayer{true}(model.u, ps.u, st.u)
    v_net = StatefulLuxLayer{true}(model.v, ps.v, st.v)
    w_net = StatefulLuxLayer{true}(model.w, ps.w, st.w)

    stats = (;
        physics_loss = physics_informed_loss_function(u_net, v_net, w_net, xyt),
        data_loss = mse_loss_function(u_net, target_data, xyt),
        bc_loss = mse_loss_function(u_net, target_bc, xyt_bc),
    )
    # Read the states only after every network call above has run.
    new_states = (; u = u_net.st, v = v_net.st, w = w_net.st)
    total_loss = stats.physics_loss + stats.data_loss + stats.bc_loss
    return (total_loss, new_states, stats)
end
loss_function (generic function with 1 method)
Generate the Data
We generate the training data from the analytical solution, sampled on the cubic space–time domain $(x, y, t) \in [0, 2]^3$.
# Exact solution u(x, y, t) = exp(x + y) * cos(x + y + 4t), broadcast elementwise over
# scalar or array arguments.
function analytical_solution(x, y, t)
    s = @. x + y
    return @. exp(s) * cos(s + 4t)
end
# Same solution for a matrix whose rows 1, 2, 3 are x, y, t (one point per column);
# returns one value per column.
analytical_solution(xyt) = @views analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :])
begin
    # Collocation/data points: a grid_len^3 tensor-product grid over [0, 2]^3, stored as a
    # 3 × grid_len^3 matrix whose rows are (x, y, t).
    grid_len = 16
    grid = range(0.0f0, 2.0f0; length = grid_len)
    xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))])
    target_data = reshape(analytical_solution(xyt), 1, :)

    # Boundary/initial points: `bc_len` points on each of the faces t = 0, x = 0, x = 2,
    # y = 0 and y = 2. The free coordinates are drawn independently (from a seeded RNG, so
    # the data stays reproducible). Using the same `range` for x, y and t would sample
    # every face only along its diagonal (e.g. x == y on t = 0).
    bc_len = 512
    bc_rng = Random.Xoshiro(0)
    x = 2.0f0 .* rand(bc_rng, Float32, bc_len)
    y = 2.0f0 .* rand(bc_rng, Float32, bc_len)
    t = 2.0f0 .* rand(bc_rng, Float32, bc_len)
    xyt_bc = hcat(
        stack((x, y, zeros(Float32, bc_len)); dims = 1),         # t = 0
        stack((zeros(Float32, bc_len), y, t); dims = 1),         # x = 0
        stack((ones(Float32, bc_len) .* 2, y, t); dims = 1),     # x = 2
        stack((x, zeros(Float32, bc_len), t); dims = 1),         # y = 0
        stack((x, ones(Float32, bc_len) .* 2, t); dims = 1)      # y = 2
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    # Targets are min-max scaled to [0, 1] using the range over both target sets.
    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc)

    # Inputs are min-max scaled to [0, 1]. Both point sets must go through the *same*
    # affine map, so the bounds are taken once from the grid, which spans the whole
    # domain, rather than separately from each set.
    min_xyt, max_xyt = extrema(xyt)
    xyt = (xyt .- min_xyt) ./ (max_xyt - min_xyt)
    xyt_bc = (xyt_bc .- min_xyt) ./ (max_xyt - min_xyt)
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end
Training
"""
    train_model(xyt, target_data, xyt_bc, target_bc; seed::Int = 0,
                maxiters::Int = 50000, hidden_dims::Int = 32)

Train a `PINN` with hidden width `hidden_dims` on the collocation/data points `xyt` (with
targets `target_data`) and the boundary points `xyt_bc` (with targets `target_bc`). It runs
for exactly `maxiters` optimizer steps and returns the trained model as a
`StatefulLuxLayer` holding CPU parameters and states.

# Training loop
- Each step takes one batch of 32 columns from each dataset. Both loaders shuffle, drop
  incomplete batches and are cycled indefinitely.
- The optimizer is Adam with a piecewise-constant learning rate: 0.05 for steps < 5000,
  0.005 for steps < 10000, and 0.0005 after that.
- Statistics are printed on steps 1, 1001, 2001, … and on the final step. Each value is
  shown next to its mean over the last 32 steps.

# Throws
`ArgumentError` if `maxiters < 1`, or if the loss becomes `NaN`.
"""
function train_model(
        xyt, target_data, xyt_bc, target_bc; seed::Int = 0,
        maxiters::Int = 50000, hidden_dims::Int = 32
    )
    maxiters ≥ 1 || throw(ArgumentError("`maxiters` must be ≥ 1, got $maxiters"))

    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> xdev

    bc_dataloader = DataLoader(
        (xyt_bc, target_bc); batchsize = 32, shuffle = true, partial = false
    ) |> xdev
    pde_dataloader = DataLoader(
        (xyt, target_data); batchsize = 32, shuffle = true, partial = false
    ) |> xdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.05f0))
    lr = i -> i < 5000 ? 0.05f0 : (i < 10000 ? 0.005f0 : 0.0005f0)

    # Rolling windows over the last 32 steps, used for the smoothed values in the log.
    total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
        _ -> OnlineStats.CircBuff(Float32, 32; rev = true), 4
    )

    # Infinite stream of (pde batch, bc batch) pairs, numbered from 1.
    batches = enumerate(zip(Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader)))
    for (iter, ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch))) in batches
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoEnzyme(), loss_function,
            (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state
        )
        isnan(loss) && throw(ArgumentError("NaN Loss Detected"))

        fit!(total_loss_tracker, Float32(loss))
        fit!(physics_loss_tracker, Float32(stats.physics_loss))
        fit!(data_loss_tracker, Float32(stats.data_loss))
        fit!(bc_loss_tracker, Float32(stats.bc_loss))

        if iter % 1000 == 1 || iter == maxiters
            mean_loss = mean(OnlineStats.value(total_loss_tracker))
            mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
            mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
            mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))
            @printf "Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
                     (%.9f) \t Data Loss: %.9f (%.9f) \t BC \
                     Loss: %.9f (%.9f)\n" iter maxiters loss mean_loss stats.physics_loss mean_physics_loss stats.data_loss mean_data_loss stats.bc_loss mean_bc_loss
        end

        # Stop only *after* completing step `maxiters`. The original `iter += 1` followed
        # by `iter ≥ maxiters` ran just `maxiters - 1` steps and never reached the
        # final-step log line.
        iter ≥ maxiters && break
    end

    return StatefulLuxLayer{true}(
        pinn, cdev(train_state.parameters), cdev(train_state.states)
    )
end
trained_model = train_model(xyt, target_data, xyt_bc, target_bc)
# Only the `u` sub-network is needed for prediction; wrap it with its trained parameters
# and states and switch it to test mode.
trained_u = let ps = trained_model.ps, st = trained_model.st
    Lux.testmode(StatefulLuxLayer{true}(trained_model.model.u, ps.u, st.u))
end
2025-03-11 23:01:01.951240: I external/xla/xla/service/service.cc:152] XLA service 0xcdcd910 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-11 23:01:01.951274: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741734061.952007 1244328 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741734061.952071 1244328 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741734061.952120 1244328 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741734061.964881 1244328 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1741734307.009729 1244328 buffer_comparator.cc:156] Difference at 16: -nan, expected 11.6059
E0000 00:00:1741734307.011181 1244328 buffer_comparator.cc:156] Difference at 17: -nan, expected 14.502
E0000 00:00:1741734307.011186 1244328 buffer_comparator.cc:156] Difference at 18: -nan, expected 11.2449
E0000 00:00:1741734307.011189 1244328 buffer_comparator.cc:156] Difference at 19: -nan, expected 10.0998
E0000 00:00:1741734307.011192 1244328 buffer_comparator.cc:156] Difference at 20: -nan, expected 14.0222
E0000 00:00:1741734307.011194 1244328 buffer_comparator.cc:156] Difference at 21: -nan, expected 10.1321
E0000 00:00:1741734307.011197 1244328 buffer_comparator.cc:156] Difference at 22: -nan, expected 10.2986
E0000 00:00:1741734307.011200 1244328 buffer_comparator.cc:156] Difference at 23: -nan, expected 14.1109
E0000 00:00:1741734307.011203 1244328 buffer_comparator.cc:156] Difference at 24: -nan, expected 13.3463
E0000 00:00:1741734307.011206 1244328 buffer_comparator.cc:156] Difference at 25: -nan, expected 12.8369
2025-03-11 23:05:07.011215: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741734307.014000 1244328 buffer_comparator.cc:156] Difference at 16: -nan, expected 11.6059
E0000 00:00:1741734307.014011 1244328 buffer_comparator.cc:156] Difference at 17: -nan, expected 14.502
E0000 00:00:1741734307.014015 1244328 buffer_comparator.cc:156] Difference at 18: -nan, expected 11.2449
E0000 00:00:1741734307.014018 1244328 buffer_comparator.cc:156] Difference at 19: -nan, expected 10.0998
E0000 00:00:1741734307.014020 1244328 buffer_comparator.cc:156] Difference at 20: -nan, expected 14.0222
E0000 00:00:1741734307.014023 1244328 buffer_comparator.cc:156] Difference at 21: -nan, expected 10.1321
E0000 00:00:1741734307.014026 1244328 buffer_comparator.cc:156] Difference at 22: -nan, expected 10.2986
E0000 00:00:1741734307.014029 1244328 buffer_comparator.cc:156] Difference at 23: -nan, expected 14.1109
E0000 00:00:1741734307.014032 1244328 buffer_comparator.cc:156] Difference at 24: -nan, expected 13.3463
E0000 00:00:1741734307.014034 1244328 buffer_comparator.cc:156] Difference at 25: -nan, expected 12.8369
2025-03-11 23:05:07.014039: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741734307.016827 1244328 buffer_comparator.cc:156] Difference at 16: -nan, expected 11.6059
E0000 00:00:1741734307.016838 1244328 buffer_comparator.cc:156] Difference at 17: -nan, expected 14.502
E0000 00:00:1741734307.016841 1244328 buffer_comparator.cc:156] Difference at 18: -nan, expected 11.2449
E0000 00:00:1741734307.016844 1244328 buffer_comparator.cc:156] Difference at 19: -nan, expected 10.0998
E0000 00:00:1741734307.016847 1244328 buffer_comparator.cc:156] Difference at 20: -nan, expected 14.0222
E0000 00:00:1741734307.016850 1244328 buffer_comparator.cc:156] Difference at 21: -nan, expected 10.1321
E0000 00:00:1741734307.016853 1244328 buffer_comparator.cc:156] Difference at 22: -nan, expected 10.2986
E0000 00:00:1741734307.016856 1244328 buffer_comparator.cc:156] Difference at 23: -nan, expected 14.1109
E0000 00:00:1741734307.016860 1244328 buffer_comparator.cc:156] Difference at 24: -nan, expected 13.3463
E0000 00:00:1741734307.016863 1244328 buffer_comparator.cc:156] Difference at 25: -nan, expected 12.8369
2025-03-11 23:05:07.016868: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741734307.019641 1244328 buffer_comparator.cc:156] Difference at 32: -nan, expected 12.4
E0000 00:00:1741734307.019653 1244328 buffer_comparator.cc:156] Difference at 33: -nan, expected 12.9454
E0000 00:00:1741734307.019656 1244328 buffer_comparator.cc:156] Difference at 34: -nan, expected 12.9462
E0000 00:00:1741734307.019659 1244328 buffer_comparator.cc:156] Difference at 35: -nan, expected 13.9775
E0000 00:00:1741734307.019662 1244328 buffer_comparator.cc:156] Difference at 36: -nan, expected 15.0433
E0000 00:00:1741734307.019665 1244328 buffer_comparator.cc:156] Difference at 37: -nan, expected 12.0589
E0000 00:00:1741734307.019668 1244328 buffer_comparator.cc:156] Difference at 38: -nan, expected 14.4629
E0000 00:00:1741734307.019670 1244328 buffer_comparator.cc:156] Difference at 39: -nan, expected 12.7671
E0000 00:00:1741734307.019673 1244328 buffer_comparator.cc:156] Difference at 40: -nan, expected 12.3584
E0000 00:00:1741734307.019676 1244328 buffer_comparator.cc:156] Difference at 41: -nan, expected 11.6002
2025-03-11 23:05:07.019681: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741734307.022453 1244328 buffer_comparator.cc:156] Difference at 32: -nan, expected 12.4
E0000 00:00:1741734307.022464 1244328 buffer_comparator.cc:156] Difference at 33: -nan, expected 12.9454
E0000 00:00:1741734307.022468 1244328 buffer_comparator.cc:156] Difference at 34: -nan, expected 12.9462
E0000 00:00:1741734307.022471 1244328 buffer_comparator.cc:156] Difference at 35: -nan, expected 13.9775
E0000 00:00:1741734307.022473 1244328 buffer_comparator.cc:156] Difference at 36: -nan, expected 15.0433
E0000 00:00:1741734307.022476 1244328 buffer_comparator.cc:156] Difference at 37: -nan, expected 12.0589
E0000 00:00:1741734307.022479 1244328 buffer_comparator.cc:156] Difference at 38: -nan, expected 14.4629
E0000 00:00:1741734307.022482 1244328 buffer_comparator.cc:156] Difference at 39: -nan, expected 12.7671
E0000 00:00:1741734307.022485 1244328 buffer_comparator.cc:156] Difference at 40: -nan, expected 12.3584
E0000 00:00:1741734307.022487 1244328 buffer_comparator.cc:156] Difference at 41: -nan, expected 11.6002
2025-03-11 23:05:07.022492: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741734307.025254 1244328 buffer_comparator.cc:156] Difference at 32: -nan, expected 12.4
E0000 00:00:1741734307.025265 1244328 buffer_comparator.cc:156] Difference at 33: -nan, expected 12.9454
E0000 00:00:1741734307.025268 1244328 buffer_comparator.cc:156] Difference at 34: -nan, expected 12.9462
E0000 00:00:1741734307.025271 1244328 buffer_comparator.cc:156] Difference at 35: -nan, expected 13.9775
E0000 00:00:1741734307.025274 1244328 buffer_comparator.cc:156] Difference at 36: -nan, expected 15.0433
E0000 00:00:1741734307.025277 1244328 buffer_comparator.cc:156] Difference at 37: -nan, expected 12.0589
E0000 00:00:1741734307.025280 1244328 buffer_comparator.cc:156] Difference at 38: -nan, expected 14.4629
E0000 00:00:1741734307.025282 1244328 buffer_comparator.cc:156] Difference at 39: -nan, expected 12.7671
E0000 00:00:1741734307.025285 1244328 buffer_comparator.cc:156] Difference at 40: -nan, expected 12.3584
E0000 00:00:1741734307.025288 1244328 buffer_comparator.cc:156] Difference at 41: -nan, expected 11.6002
2025-03-11 23:05:07.025292: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Iteration: [ 1/ 50000] Loss: 3.158992767 (3.158992767) Physics Loss: 1.982357264 (1.982357264) Data Loss: 0.578243732 (0.578243732) BC Loss: 0.598391712 (0.598391712)
Iteration: [ 1001/ 50000] Loss: 0.034934171 (0.028678153) Physics Loss: 0.000550666 (0.000427971) Data Loss: 0.022407684 (0.011849238) BC Loss: 0.011975820 (0.016400939)
Iteration: [ 2001/ 50000] Loss: 0.029342793 (0.031756539) Physics Loss: 0.001803906 (0.001986223) Data Loss: 0.013016323 (0.012399685) BC Loss: 0.014522564 (0.017370628)
Iteration: [ 3001/ 50000] Loss: 0.021532167 (0.027456025) Physics Loss: 0.002033974 (0.004514552) Data Loss: 0.004263131 (0.008210877) BC Loss: 0.015235063 (0.014730594)
Iteration: [ 4001/ 50000] Loss: 0.049741872 (0.035063371) Physics Loss: 0.008673832 (0.002975470) Data Loss: 0.016625574 (0.012557827) BC Loss: 0.024442466 (0.019530077)
Iteration: [ 5001/ 50000] Loss: 0.021555956 (0.037785888) Physics Loss: 0.002584497 (0.009353531) Data Loss: 0.008988982 (0.011811271) BC Loss: 0.009982477 (0.016621085)
Iteration: [ 6001/ 50000] Loss: 0.022884602 (0.019320106) Physics Loss: 0.000503887 (0.000844118) Data Loss: 0.005671175 (0.006657102) BC Loss: 0.016709540 (0.011818888)
Iteration: [ 7001/ 50000] Loss: 0.017019382 (0.019912010) Physics Loss: 0.000786213 (0.000926178) Data Loss: 0.006546173 (0.007675517) BC Loss: 0.009686996 (0.011310317)
Iteration: [ 8001/ 50000] Loss: 0.022041641 (0.018215429) Physics Loss: 0.002748450 (0.002823828) Data Loss: 0.004308246 (0.004818310) BC Loss: 0.014984946 (0.010573289)
Iteration: [ 9001/ 50000] Loss: 0.020825051 (0.016353965) Physics Loss: 0.001645100 (0.001697728) Data Loss: 0.003023159 (0.004666437) BC Loss: 0.016156793 (0.009989802)
Iteration: [ 10001/ 50000] Loss: 0.012419771 (0.014550619) Physics Loss: 0.001211297 (0.002143627) Data Loss: 0.006038338 (0.003847382) BC Loss: 0.005170135 (0.008559611)
Iteration: [ 11001/ 50000] Loss: 0.016400268 (0.013275324) Physics Loss: 0.001640767 (0.001365761) Data Loss: 0.003816272 (0.003336374) BC Loss: 0.010943229 (0.008573188)
Iteration: [ 12001/ 50000] Loss: 0.012625601 (0.012438955) Physics Loss: 0.001496360 (0.001293942) Data Loss: 0.005017307 (0.003183292) BC Loss: 0.006111934 (0.007961722)
Iteration: [ 13001/ 50000] Loss: 0.006276353 (0.011658882) Physics Loss: 0.000723396 (0.001394610) Data Loss: 0.001822183 (0.002704730) BC Loss: 0.003730775 (0.007559542)
Iteration: [ 14001/ 50000] Loss: 0.007330933 (0.010494760) Physics Loss: 0.001231446 (0.001130134) Data Loss: 0.003328452 (0.002512253) BC Loss: 0.002771034 (0.006852372)
Iteration: [ 15001/ 50000] Loss: 0.012322948 (0.010581179) Physics Loss: 0.001112666 (0.001342141) Data Loss: 0.000318603 (0.002849674) BC Loss: 0.010891679 (0.006389363)
Iteration: [ 16001/ 50000] Loss: 0.005651589 (0.011476245) Physics Loss: 0.001261572 (0.001271797) Data Loss: 0.001857541 (0.002460307) BC Loss: 0.002532476 (0.007744140)
Iteration: [ 17001/ 50000] Loss: 0.004589280 (0.010282698) Physics Loss: 0.000900613 (0.001197141) Data Loss: 0.000789952 (0.002489067) BC Loss: 0.002898715 (0.006596491)
Iteration: [ 18001/ 50000] Loss: 0.013593317 (0.008794412) Physics Loss: 0.001301785 (0.001351097) Data Loss: 0.003135182 (0.001751710) BC Loss: 0.009156350 (0.005691605)
Iteration: [ 19001/ 50000] Loss: 0.009044455 (0.009721650) Physics Loss: 0.001797066 (0.001625932) Data Loss: 0.003098861 (0.001854344) BC Loss: 0.004148528 (0.006241375)
Iteration: [ 20001/ 50000] Loss: 0.004531536 (0.007789471) Physics Loss: 0.002314395 (0.001649175) Data Loss: 0.001491809 (0.001308756) BC Loss: 0.000725332 (0.004831538)
Iteration: [ 21001/ 50000] Loss: 0.004263383 (0.006220151) Physics Loss: 0.001818714 (0.001836546) Data Loss: 0.001014963 (0.001413772) BC Loss: 0.001429706 (0.002969833)
Iteration: [ 22001/ 50000] Loss: 0.006228818 (0.005886041) Physics Loss: 0.003274539 (0.002052142) Data Loss: 0.002419390 (0.000916688) BC Loss: 0.000534889 (0.002917211)
Iteration: [ 23001/ 50000] Loss: 0.004727511 (0.004308250) Physics Loss: 0.001700793 (0.001910425) Data Loss: 0.001145163 (0.000927515) BC Loss: 0.001881554 (0.001470311)
Iteration: [ 24001/ 50000] Loss: 0.005070512 (0.003963022) Physics Loss: 0.001458622 (0.001738792) Data Loss: 0.000377079 (0.000693607) BC Loss: 0.003234812 (0.001530622)
Iteration: [ 25001/ 50000] Loss: 0.002241082 (0.003130429) Physics Loss: 0.000983181 (0.001482069) Data Loss: 0.000401148 (0.000641221) BC Loss: 0.000856754 (0.001007139)
Iteration: [ 26001/ 50000] Loss: 0.002955514 (0.002503448) Physics Loss: 0.001489272 (0.001270391) Data Loss: 0.000252445 (0.000556399) BC Loss: 0.001213797 (0.000676658)
Iteration: [ 27001/ 50000] Loss: 0.003096203 (0.002668173) Physics Loss: 0.001007635 (0.001394353) Data Loss: 0.000529087 (0.000551110) BC Loss: 0.001559482 (0.000722710)
Iteration: [ 28001/ 50000] Loss: 0.001113268 (0.002219696) Physics Loss: 0.000669667 (0.001248004) Data Loss: 0.000210634 (0.000505303) BC Loss: 0.000232967 (0.000466388)
Iteration: [ 29001/ 50000] Loss: 0.001584558 (0.001840178) Physics Loss: 0.000857728 (0.000936380) Data Loss: 0.000405950 (0.000491625) BC Loss: 0.000320880 (0.000412173)
Iteration: [ 30001/ 50000] Loss: 0.001524673 (0.001886792) Physics Loss: 0.000554989 (0.001022693) Data Loss: 0.000461125 (0.000466423) BC Loss: 0.000508558 (0.000397676)
Iteration: [ 31001/ 50000] Loss: 0.001906899 (0.002045441) Physics Loss: 0.000959212 (0.001217105) Data Loss: 0.000290614 (0.000451727) BC Loss: 0.000657073 (0.000376609)
Iteration: [ 32001/ 50000] Loss: 0.003575745 (0.002317760) Physics Loss: 0.002933637 (0.001479890) Data Loss: 0.000278991 (0.000500438) BC Loss: 0.000363118 (0.000337431)
Iteration: [ 33001/ 50000] Loss: 0.001054667 (0.001392871) Physics Loss: 0.000542178 (0.000709339) Data Loss: 0.000279479 (0.000395160) BC Loss: 0.000233010 (0.000288372)
Iteration: [ 34001/ 50000] Loss: 0.001258306 (0.001430126) Physics Loss: 0.000604900 (0.000740890) Data Loss: 0.000279625 (0.000439273) BC Loss: 0.000373781 (0.000249964)
Iteration: [ 35001/ 50000] Loss: 0.001458327 (0.001451151) Physics Loss: 0.001126388 (0.000859504) Data Loss: 0.000209119 (0.000364929) BC Loss: 0.000122820 (0.000226719)
Iteration: [ 36001/ 50000] Loss: 0.001965113 (0.001363893) Physics Loss: 0.000819583 (0.000818373) Data Loss: 0.000884497 (0.000343106) BC Loss: 0.000261034 (0.000202413)
Iteration: [ 37001/ 50000] Loss: 0.001001803 (0.001306025) Physics Loss: 0.000458097 (0.000750440) Data Loss: 0.000366133 (0.000375299) BC Loss: 0.000177574 (0.000180287)
Iteration: [ 38001/ 50000] Loss: 0.001594038 (0.001202732) Physics Loss: 0.000834267 (0.000656006) Data Loss: 0.000438084 (0.000355202) BC Loss: 0.000321687 (0.000191524)
Iteration: [ 39001/ 50000] Loss: 0.001096051 (0.001220055) Physics Loss: 0.000612216 (0.000665410) Data Loss: 0.000219414 (0.000392687) BC Loss: 0.000264422 (0.000161957)
Iteration: [ 40001/ 50000] Loss: 0.001038662 (0.001444635) Physics Loss: 0.000485165 (0.000887432) Data Loss: 0.000353080 (0.000370217) BC Loss: 0.000200418 (0.000186986)
Iteration: [ 41001/ 50000] Loss: 0.000912517 (0.001242174) Physics Loss: 0.000559338 (0.000678301) Data Loss: 0.000281135 (0.000411240) BC Loss: 0.000072045 (0.000152633)
Iteration: [ 42001/ 50000] Loss: 0.001224924 (0.001465745) Physics Loss: 0.001037403 (0.000994991) Data Loss: 0.000128774 (0.000311622) BC Loss: 0.000058747 (0.000159133)
Iteration: [ 43001/ 50000] Loss: 0.000964010 (0.001390206) Physics Loss: 0.000636658 (0.000894072) Data Loss: 0.000201709 (0.000336276) BC Loss: 0.000125642 (0.000159858)
Iteration: [ 44001/ 50000] Loss: 0.000622213 (0.001140883) Physics Loss: 0.000339070 (0.000657654) Data Loss: 0.000136122 (0.000339579) BC Loss: 0.000147021 (0.000143649)
Iteration: [ 45001/ 50000] Loss: 0.000991343 (0.001148664) Physics Loss: 0.000418792 (0.000707322) Data Loss: 0.000449356 (0.000315701) BC Loss: 0.000123195 (0.000125641)
Iteration: [ 46001/ 50000] Loss: 0.000918692 (0.000882140) Physics Loss: 0.000450692 (0.000441957) Data Loss: 0.000271820 (0.000319554) BC Loss: 0.000196180 (0.000120629)
Iteration: [ 47001/ 50000] Loss: 0.001575661 (0.001149394) Physics Loss: 0.001348543 (0.000694363) Data Loss: 0.000169103 (0.000319502) BC Loss: 0.000058015 (0.000135529)
Iteration: [ 48001/ 50000] Loss: 0.001518319 (0.001056366) Physics Loss: 0.001149551 (0.000625095) Data Loss: 0.000176694 (0.000317998) BC Loss: 0.000192074 (0.000113272)
Iteration: [ 49001/ 50000] Loss: 0.002202400 (0.001657575) Physics Loss: 0.001789718 (0.001203460) Data Loss: 0.000310050 (0.000312800) BC Loss: 0.000102632 (0.000141316)
Visualizing the Results
ts, xs, ys = 0.0f0:0.05f0:2.0f0, 0.0f0:0.02f0:2.0f0, 0.0f0:0.02f0:2.0f0
# Evaluation points as a 3 × N matrix with rows (x, y, t); x varies fastest.
grid = stack(vec([collect(point) for point in Iterators.product(xs, ys, ts)]))
u_real = reshape(analytical_solution(grid), (length(xs), length(ys), length(ts)))
# Scale the inputs to [0, 1] the same way as during training, then undo the target
# scaling on the network output.
grid_normalized = let lo = minimum(grid), hi = maximum(grid)
    (grid .- lo) ./ (hi - lo)
end
u_pred = reshape(trained_u(grid_normalized), (length(xs), length(ys), length(ts)))
u_pred = u_pred .* (max_pde_val - min_pde_val) .+ min_pde_val
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")
    # Absolute prediction error on the (x, y) grid, one matrix per time in `ts`.
    errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in eachindex(ts)]
    # One color range shared by all frames, so the frames are comparable and the
    # colorbar matches what is drawn.
    err_limits = extrema(stack(errs))

    # Plot once against an Observable and only update its value per frame. Calling
    # `heatmap!`/`contour!` inside `record` would stack new plots on top of the old ones.
    err_frame = Observable(errs[1])
    hm = heatmap!(ax, xs, ys, err_frame; colorrange = err_limits)
    contour!(ax, xs, ys, err_frame; levels = 10, linewidth = 2)  # drawn above the heatmap
    Colorbar(fig[1, 2], hm)

    CairoMakie.record(fig, "pinn_nested_ad.gif", eachindex(ts); framerate = 10) do i
        ax.title = "Abs. Predictor Error | Time: $(ts[i])"
        err_frame[] = errs[i]
        return fig
    end
    fig
end
Appendix
using InteractiveUtils
# Report the Julia/system environment the page was built with.
InteractiveUtils.versioninfo()
# Additionally report GPU toolchain versions, but only for GPU packages that are loaded
# in this session and whose device is functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.