Skip to content

Training a PINN on 2D PDE

In this tutorial we will go over using a PINN to solve 2D PDEs. We will use the system from the NeuralPDE tutorials; however, we will write our own custom loss function and rely on the nested AD capabilities of Lux.jl.

This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the package: NeuralPDE.jl.

Package Imports

julia
using Lux, Optimisers, Random, Printf, Statistics, MLUtils, OnlineStats, CairoMakie,
      Reactant, Enzyme

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Problem Definition

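Since Lux supports efficient nested AD up to 2nd order, we will rewrite the problem with first-order derivatives, so that we can compute the gradients of the loss using 2nd order AD. Concretely, the second-order equation $\partial_t u = \partial_x^2 u + \partial_y^2 u$ is recast by introducing the auxiliary fields $v = \partial_x u$ and $w = \partial_y u$, which gives the first-order system

$$\partial_t u - \partial_x v - \partial_y w = 0, \qquad v - \partial_x u = 0, \qquad w - \partial_y u = 0.$$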

Define the Neural Networks

All the networks take 3 input variables and output a scalar value. Here, we will define a wrapper over the 3 networks, so that we can train them using Training.TrainState.

julia
struct PINN{U, V, W} <: Lux.AbstractLuxContainerLayer{(:u, :v, :w)}
    u::U
    v::V
    w::W
end

function create_mlp(act, hidden_dims)
    return Chain(
        Dense(3 => hidden_dims, act),
        Dense(hidden_dims => hidden_dims, act),
        Dense(hidden_dims => hidden_dims, act),
        Dense(hidden_dims => 1)
    )
end

function PINN(; hidden_dims::Int=32)
    return PINN(
        create_mlp(tanh, hidden_dims),
        create_mlp(tanh, hidden_dims),
        create_mlp(tanh, hidden_dims)
    )
end
Main.var"##230".PINN

Define the Loss Functions

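We will define a custom loss function to compute the loss using 2nd order AD. We will use the following loss function, which penalizes the mean squared residual of each equation of the first-order system:

$$\mathcal{L}_{\text{physics}} = \operatorname{mean}\left|\partial_t u - \partial_x v - \partial_y w\right|^2 + \operatorname{mean}\left|v - \partial_x u\right|^2 + \operatorname{mean}\left|w - \partial_y u\right|^2$$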

julia
@views function physics_informed_loss_function(
        u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray
)
    ∂u_∂xyt = Enzyme.gradient(Enzyme.Reverse, sum  u, xyt)[1]
    ∂u_∂x, ∂u_∂y, ∂u_∂t = ∂u_∂xyt[1:1, :], ∂u_∂xyt[2:2, :], ∂u_∂xyt[3:3, :]
    ∂v_∂x = Enzyme.gradient(Enzyme.Reverse, sum  v, xyt)[1][1:1, :]
    v_xyt = v(xyt)
    ∂w_∂y = Enzyme.gradient(Enzyme.Reverse, sum  w, xyt)[1][2:2, :]
    w_xyt = w(xyt)
    return (
        mean(abs2, ∂u_∂t .- ∂v_∂x .- ∂w_∂y) +
        mean(abs2, v_xyt .- ∂u_∂x) +
        mean(abs2, w_xyt .- ∂u_∂y)
    )
end
physics_informed_loss_function (generic function with 1 method)

Additionally, we need to compute the loss wrt the boundary conditions.

julia
function mse_loss_function(u::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray)
    return MSELoss()(u(xyt), target)
end

function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
    u_net = StatefulLuxLayer{true}(model.u, ps.u, st.u)
    v_net = StatefulLuxLayer{true}(model.v, ps.v, st.v)
    w_net = StatefulLuxLayer{true}(model.w, ps.w, st.w)
    physics_loss = physics_informed_loss_function(u_net, v_net, w_net, xyt)
    data_loss = mse_loss_function(u_net, target_data, xyt)
    bc_loss = mse_loss_function(u_net, target_bc, xyt_bc)
    loss = physics_loss + data_loss + bc_loss
    return (
        loss,
        (; u=u_net.st, v=v_net.st, w=w_net.st),
        (; physics_loss, data_loss, bc_loss)
    )
end
loss_function (generic function with 1 method)

Generate the Data

We will generate some random data to train the model on. We will take data on a square spatial and temporal domain x ∈ [0, 2], y ∈ [0, 2], and t ∈ [0, 2]. Typically, you want to be smarter about the sampling process, but for the sake of simplicity, we will skip that.

julia
analytical_solution(x, y, t) = @. exp(x + y) * cos(x + y + 4t)
analytical_solution(xyt) = analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :])

begin
    grid_len = 16

    grid = range(0.0f0, 2.0f0; length=grid_len)
    xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))])

    target_data = reshape(analytical_solution(xyt), 1, :)

    bc_len = 512

    x = collect(range(0.0f0, 2.0f0; length=bc_len))
    y = collect(range(0.0f0, 2.0f0; length=bc_len))
    t = collect(range(0.0f0, 2.0f0; length=bc_len))

    xyt_bc = hcat(
        stack((x, y, zeros(Float32, bc_len)); dims=1),
        stack((zeros(Float32, bc_len), y, t); dims=1),
        stack((ones(Float32, bc_len) .* 2, y, t); dims=1),
        stack((x, zeros(Float32, bc_len), t); dims=1),
        stack((x, ones(Float32, bc_len) .* 2, t); dims=1)
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc)

    xyt = (xyt .- minimum(xyt)) ./ (maximum(xyt) .- minimum(xyt))
    xyt_bc = (xyt_bc .- minimum(xyt_bc)) ./ (maximum(xyt_bc) .- minimum(xyt_bc))
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end

Training

julia
function train_model(
        xyt, target_data, xyt_bc, target_bc; seed::Int=0,
        maxiters::Int=50000, hidden_dims::Int=32
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> xdev

    bc_dataloader = DataLoader(
        (xyt_bc, target_bc); batchsize=32, shuffle=true, partial=false
    ) |> xdev
    pde_dataloader = DataLoader(
        (xyt, target_data); batchsize=32, shuffle=true, partial=false
    ) |> xdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.05f0))
    lr = i -> i < 5000 ? 0.05f0 : (i < 10000 ? 0.005f0 : 0.0005f0)

    total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
        _ -> OnlineStats.CircBuff(Float32, 32; rev=true), 4)

    iter = 1
    for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in zip(
        Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader)
    )
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoEnzyme(), loss_function,
            (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state
        )

        fit!(total_loss_tracker, Float32(loss))
        fit!(physics_loss_tracker, Float32(stats.physics_loss))
        fit!(data_loss_tracker, Float32(stats.data_loss))
        fit!(bc_loss_tracker, Float32(stats.bc_loss))

        mean_loss = mean(OnlineStats.value(total_loss_tracker))
        mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
        mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
        mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))

        isnan(loss) && throw(ArgumentError("NaN Loss Detected"))

        if iter % 1000 == 1 || iter == maxiters
            @printf "Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
                     (%.9f) \t Data Loss: %.9f (%.9f) \t BC \
                     Loss: %.9f (%.9f)\n" iter maxiters loss mean_loss stats.physics_loss mean_physics_loss stats.data_loss mean_data_loss stats.bc_loss mean_bc_loss
        end

        iter += 1
        iter  maxiters && break
    end

    return StatefulLuxLayer{true}(
        pinn, cdev(train_state.parameters), cdev(train_state.states)
    )
end

trained_model = train_model(xyt, target_data, xyt_bc, target_bc)
trained_u = Lux.testmode(
    StatefulLuxLayer{true}(trained_model.model.u, trained_model.ps.u, trained_model.st.u)
)
2025-02-05 19:41:16.910393: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:51] XLA (re)initializing LLVM with options fingerprint: 3139691659875067101
E0205 19:41:17.242475 3430985 buffer_comparator.cc:156] Difference at 16: -nan, expected 11.6059
E0205 19:41:17.243367 3430985 buffer_comparator.cc:156] Difference at 17: -nan, expected 14.502
E0205 19:41:17.243373 3430985 buffer_comparator.cc:156] Difference at 18: -nan, expected 11.2449
E0205 19:41:17.243377 3430985 buffer_comparator.cc:156] Difference at 19: -nan, expected 10.0998
E0205 19:41:17.243380 3430985 buffer_comparator.cc:156] Difference at 20: -nan, expected 14.0222
E0205 19:41:17.243384 3430985 buffer_comparator.cc:156] Difference at 21: -nan, expected 10.1321
E0205 19:41:17.243388 3430985 buffer_comparator.cc:156] Difference at 22: -nan, expected 10.2986
E0205 19:41:17.243392 3430985 buffer_comparator.cc:156] Difference at 23: -nan, expected 14.1109
E0205 19:41:17.243396 3430985 buffer_comparator.cc:156] Difference at 24: -nan, expected 13.3463
E0205 19:41:17.243399 3430985 buffer_comparator.cc:156] Difference at 25: -nan, expected 12.8369
2025-02-05 19:41:17.243414: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1081] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0205 19:41:17.246016 3430985 buffer_comparator.cc:156] Difference at 16: -nan, expected 11.6059
E0205 19:41:17.246035 3430985 buffer_comparator.cc:156] Difference at 17: -nan, expected 14.502
E0205 19:41:17.246039 3430985 buffer_comparator.cc:156] Difference at 18: -nan, expected 11.2449
E0205 19:41:17.246043 3430985 buffer_comparator.cc:156] Difference at 19: -nan, expected 10.0998
E0205 19:41:17.246046 3430985 buffer_comparator.cc:156] Difference at 20: -nan, expected 14.0222
E0205 19:41:17.246061 3430985 buffer_comparator.cc:156] Difference at 21: -nan, expected 10.1321
E0205 19:41:17.246065 3430985 buffer_comparator.cc:156] Difference at 22: -nan, expected 10.2986
E0205 19:41:17.246068 3430985 buffer_comparator.cc:156] Difference at 23: -nan, expected 14.1109
E0205 19:41:17.246072 3430985 buffer_comparator.cc:156] Difference at 24: -nan, expected 13.3463
E0205 19:41:17.246076 3430985 buffer_comparator.cc:156] Difference at 25: -nan, expected 12.8369
2025-02-05 19:41:17.246082: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1081] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0205 19:41:17.248638 3430985 buffer_comparator.cc:156] Difference at 1024: -nan, expected 11.5293
E0205 19:41:17.248659 3430985 buffer_comparator.cc:156] Difference at 1025: -nan, expected 10.1983
E0205 19:41:17.248663 3430985 buffer_comparator.cc:156] Difference at 1026: -nan, expected 13.3385
E0205 19:41:17.248667 3430985 buffer_comparator.cc:156] Difference at 1027: -nan, expected 12.4705
E0205 19:41:17.248671 3430985 buffer_comparator.cc:156] Difference at 1028: -nan, expected 8.94387
E0205 19:41:17.248675 3430985 buffer_comparator.cc:156] Difference at 1029: -nan, expected 10.8997
E0205 19:41:17.248678 3430985 buffer_comparator.cc:156] Difference at 1030: -nan, expected 10.6486
E0205 19:41:17.248682 3430985 buffer_comparator.cc:156] Difference at 1031: -nan, expected 9.73507
E0205 19:41:17.248686 3430985 buffer_comparator.cc:156] Difference at 1032: -nan, expected 12.2806
E0205 19:41:17.248690 3430985 buffer_comparator.cc:156] Difference at 1033: -nan, expected 10.1883
2025-02-05 19:41:17.248696: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1081] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0205 19:41:17.251263 3430985 buffer_comparator.cc:156] Difference at 1040: -nan, expected 9.99799
E0205 19:41:17.251293 3430985 buffer_comparator.cc:156] Difference at 1041: -nan, expected 12.209
E0205 19:41:17.251298 3430985 buffer_comparator.cc:156] Difference at 1042: -nan, expected 9.4851
E0205 19:41:17.251301 3430985 buffer_comparator.cc:156] Difference at 1043: -nan, expected 8.26397
E0205 19:41:17.251305 3430985 buffer_comparator.cc:156] Difference at 1044: -nan, expected 11.9253
E0205 19:41:17.251311 3430985 buffer_comparator.cc:156] Difference at 1045: -nan, expected 8.99047
E0205 19:41:17.251314 3430985 buffer_comparator.cc:156] Difference at 1046: -nan, expected 8.81842
E0205 19:41:17.251318 3430985 buffer_comparator.cc:156] Difference at 1047: -nan, expected 12.2714
E0205 19:41:17.251322 3430985 buffer_comparator.cc:156] Difference at 1048: -nan, expected 11.1417
E0205 19:41:17.251326 3430985 buffer_comparator.cc:156] Difference at 1049: -nan, expected 10.6572
2025-02-05 19:41:17.251332: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1081] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0205 19:41:17.253908 3430985 buffer_comparator.cc:156] Difference at 1056: -nan, expected 10.6543
E0205 19:41:17.253926 3430985 buffer_comparator.cc:156] Difference at 1057: -nan, expected 11.0945
E0205 19:41:17.253930 3430985 buffer_comparator.cc:156] Difference at 1058: -nan, expected 11.1424
E0205 19:41:17.253934 3430985 buffer_comparator.cc:156] Difference at 1059: -nan, expected 12.7556
E0205 19:41:17.253938 3430985 buffer_comparator.cc:156] Difference at 1060: -nan, expected 12.6932
E0205 19:41:17.253942 3430985 buffer_comparator.cc:156] Difference at 1061: -nan, expected 10.0594
E0205 19:41:17.253946 3430985 buffer_comparator.cc:156] Difference at 1062: -nan, expected 12.3478
E0205 19:41:17.253949 3430985 buffer_comparator.cc:156] Difference at 1063: -nan, expected 10.8381
E0205 19:41:17.253953 3430985 buffer_comparator.cc:156] Difference at 1064: -nan, expected 10.409
E0205 19:41:17.253957 3430985 buffer_comparator.cc:156] Difference at 1065: -nan, expected 10.3688
2025-02-05 19:41:17.253963: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1081] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0205 19:41:17.256511 3430985 buffer_comparator.cc:156] Difference at 1056: -nan, expected 10.6543
E0205 19:41:17.256526 3430985 buffer_comparator.cc:156] Difference at 1057: -nan, expected 11.0945
E0205 19:41:17.256529 3430985 buffer_comparator.cc:156] Difference at 1058: -nan, expected 11.1424
E0205 19:41:17.256531 3430985 buffer_comparator.cc:156] Difference at 1059: -nan, expected 12.7556
E0205 19:41:17.256534 3430985 buffer_comparator.cc:156] Difference at 1060: -nan, expected 12.6932
E0205 19:41:17.256536 3430985 buffer_comparator.cc:156] Difference at 1061: -nan, expected 10.0594
E0205 19:41:17.256539 3430985 buffer_comparator.cc:156] Difference at 1062: -nan, expected 12.3478
E0205 19:41:17.256541 3430985 buffer_comparator.cc:156] Difference at 1063: -nan, expected 10.8381
E0205 19:41:17.256544 3430985 buffer_comparator.cc:156] Difference at 1064: -nan, expected 10.409
E0205 19:41:17.256546 3430985 buffer_comparator.cc:156] Difference at 1065: -nan, expected 10.3688
2025-02-05 19:41:17.256550: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1081] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0205 19:41:17.259067 3430985 buffer_comparator.cc:156] Difference at 1056: -nan, expected 10.6543
E0205 19:41:17.259081 3430985 buffer_comparator.cc:156] Difference at 1057: -nan, expected 11.0945
E0205 19:41:17.259083 3430985 buffer_comparator.cc:156] Difference at 1058: -nan, expected 11.1424
E0205 19:41:17.259086 3430985 buffer_comparator.cc:156] Difference at 1059: -nan, expected 12.7556
E0205 19:41:17.259089 3430985 buffer_comparator.cc:156] Difference at 1060: -nan, expected 12.6932
E0205 19:41:17.259091 3430985 buffer_comparator.cc:156] Difference at 1061: -nan, expected 10.0594
E0205 19:41:17.259093 3430985 buffer_comparator.cc:156] Difference at 1062: -nan, expected 12.3478
E0205 19:41:17.259096 3430985 buffer_comparator.cc:156] Difference at 1063: -nan, expected 10.8381
E0205 19:41:17.259099 3430985 buffer_comparator.cc:156] Difference at 1064: -nan, expected 10.409
E0205 19:41:17.259101 3430985 buffer_comparator.cc:156] Difference at 1065: -nan, expected 10.3688
2025-02-05 19:41:17.259107: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1081] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Iteration: [     1/ 50000] 	 Loss: 3.158992767 (3.158992767) 	 Physics Loss: 1.982357264 (1.982357264) 	 Data Loss: 0.578243732 (0.578243732) 	 BC Loss: 0.598391712 (0.598391712)
Iteration: [  1001/ 50000] 	 Loss: 0.034934171 (0.028678153) 	 Physics Loss: 0.000550666 (0.000427971) 	 Data Loss: 0.022407684 (0.011849238) 	 BC Loss: 0.011975820 (0.016400939)
Iteration: [  2001/ 50000] 	 Loss: 0.029342793 (0.031756539) 	 Physics Loss: 0.001803906 (0.001986223) 	 Data Loss: 0.013016323 (0.012399685) 	 BC Loss: 0.014522564 (0.017370628)
Iteration: [  3001/ 50000] 	 Loss: 0.021532167 (0.027456025) 	 Physics Loss: 0.002033974 (0.004514552) 	 Data Loss: 0.004263131 (0.008210877) 	 BC Loss: 0.015235063 (0.014730594)
Iteration: [  4001/ 50000] 	 Loss: 0.049741872 (0.035063371) 	 Physics Loss: 0.008673832 (0.002975470) 	 Data Loss: 0.016625574 (0.012557827) 	 BC Loss: 0.024442466 (0.019530077)
Iteration: [  5001/ 50000] 	 Loss: 0.021555956 (0.037785888) 	 Physics Loss: 0.002584497 (0.009353531) 	 Data Loss: 0.008988982 (0.011811271) 	 BC Loss: 0.009982477 (0.016621085)
Iteration: [  6001/ 50000] 	 Loss: 0.022884602 (0.019320106) 	 Physics Loss: 0.000503887 (0.000844118) 	 Data Loss: 0.005671175 (0.006657102) 	 BC Loss: 0.016709540 (0.011818888)
Iteration: [  7001/ 50000] 	 Loss: 0.017019382 (0.019912010) 	 Physics Loss: 0.000786213 (0.000926178) 	 Data Loss: 0.006546173 (0.007675517) 	 BC Loss: 0.009686996 (0.011310317)
Iteration: [  8001/ 50000] 	 Loss: 0.022041641 (0.018215429) 	 Physics Loss: 0.002748450 (0.002823828) 	 Data Loss: 0.004308246 (0.004818310) 	 BC Loss: 0.014984946 (0.010573289)
Iteration: [  9001/ 50000] 	 Loss: 0.020825051 (0.016353965) 	 Physics Loss: 0.001645100 (0.001697728) 	 Data Loss: 0.003023159 (0.004666437) 	 BC Loss: 0.016156793 (0.009989802)
Iteration: [ 10001/ 50000] 	 Loss: 0.012419771 (0.014550619) 	 Physics Loss: 0.001211297 (0.002143627) 	 Data Loss: 0.006038338 (0.003847382) 	 BC Loss: 0.005170135 (0.008559611)
Iteration: [ 11001/ 50000] 	 Loss: 0.016400268 (0.013275324) 	 Physics Loss: 0.001640767 (0.001365761) 	 Data Loss: 0.003816272 (0.003336374) 	 BC Loss: 0.010943229 (0.008573188)
Iteration: [ 12001/ 50000] 	 Loss: 0.012625601 (0.012438955) 	 Physics Loss: 0.001496360 (0.001293942) 	 Data Loss: 0.005017307 (0.003183292) 	 BC Loss: 0.006111934 (0.007961722)
Iteration: [ 13001/ 50000] 	 Loss: 0.006276353 (0.011658882) 	 Physics Loss: 0.000723396 (0.001394610) 	 Data Loss: 0.001822183 (0.002704730) 	 BC Loss: 0.003730775 (0.007559542)
Iteration: [ 14001/ 50000] 	 Loss: 0.007330933 (0.010494760) 	 Physics Loss: 0.001231446 (0.001130134) 	 Data Loss: 0.003328452 (0.002512253) 	 BC Loss: 0.002771034 (0.006852372)
Iteration: [ 15001/ 50000] 	 Loss: 0.012322948 (0.010581179) 	 Physics Loss: 0.001112666 (0.001342141) 	 Data Loss: 0.000318603 (0.002849674) 	 BC Loss: 0.010891679 (0.006389363)
Iteration: [ 16001/ 50000] 	 Loss: 0.005651589 (0.011476245) 	 Physics Loss: 0.001261572 (0.001271797) 	 Data Loss: 0.001857541 (0.002460307) 	 BC Loss: 0.002532476 (0.007744140)
Iteration: [ 17001/ 50000] 	 Loss: 0.004589280 (0.010282698) 	 Physics Loss: 0.000900613 (0.001197141) 	 Data Loss: 0.000789952 (0.002489067) 	 BC Loss: 0.002898715 (0.006596491)
Iteration: [ 18001/ 50000] 	 Loss: 0.013593317 (0.008794412) 	 Physics Loss: 0.001301785 (0.001351097) 	 Data Loss: 0.003135182 (0.001751710) 	 BC Loss: 0.009156350 (0.005691605)
Iteration: [ 19001/ 50000] 	 Loss: 0.009044455 (0.009721650) 	 Physics Loss: 0.001797066 (0.001625932) 	 Data Loss: 0.003098861 (0.001854344) 	 BC Loss: 0.004148528 (0.006241375)
Iteration: [ 20001/ 50000] 	 Loss: 0.004531536 (0.007789471) 	 Physics Loss: 0.002314395 (0.001649175) 	 Data Loss: 0.001491809 (0.001308756) 	 BC Loss: 0.000725332 (0.004831538)
Iteration: [ 21001/ 50000] 	 Loss: 0.004263383 (0.006220151) 	 Physics Loss: 0.001818714 (0.001836546) 	 Data Loss: 0.001014963 (0.001413772) 	 BC Loss: 0.001429706 (0.002969833)
Iteration: [ 22001/ 50000] 	 Loss: 0.006228818 (0.005886041) 	 Physics Loss: 0.003274539 (0.002052142) 	 Data Loss: 0.002419390 (0.000916688) 	 BC Loss: 0.000534889 (0.002917211)
Iteration: [ 23001/ 50000] 	 Loss: 0.004727511 (0.004308250) 	 Physics Loss: 0.001700793 (0.001910425) 	 Data Loss: 0.001145163 (0.000927515) 	 BC Loss: 0.001881554 (0.001470311)
Iteration: [ 24001/ 50000] 	 Loss: 0.005070512 (0.003963022) 	 Physics Loss: 0.001458622 (0.001738792) 	 Data Loss: 0.000377079 (0.000693607) 	 BC Loss: 0.003234812 (0.001530622)
Iteration: [ 25001/ 50000] 	 Loss: 0.002241082 (0.003130429) 	 Physics Loss: 0.000983181 (0.001482069) 	 Data Loss: 0.000401148 (0.000641221) 	 BC Loss: 0.000856754 (0.001007139)
Iteration: [ 26001/ 50000] 	 Loss: 0.002955514 (0.002503448) 	 Physics Loss: 0.001489272 (0.001270391) 	 Data Loss: 0.000252445 (0.000556399) 	 BC Loss: 0.001213797 (0.000676658)
Iteration: [ 27001/ 50000] 	 Loss: 0.003096203 (0.002668173) 	 Physics Loss: 0.001007635 (0.001394353) 	 Data Loss: 0.000529087 (0.000551110) 	 BC Loss: 0.001559482 (0.000722710)
Iteration: [ 28001/ 50000] 	 Loss: 0.001113268 (0.002219696) 	 Physics Loss: 0.000669667 (0.001248004) 	 Data Loss: 0.000210634 (0.000505303) 	 BC Loss: 0.000232967 (0.000466388)
Iteration: [ 29001/ 50000] 	 Loss: 0.001584558 (0.001840178) 	 Physics Loss: 0.000857728 (0.000936380) 	 Data Loss: 0.000405950 (0.000491625) 	 BC Loss: 0.000320880 (0.000412173)
Iteration: [ 30001/ 50000] 	 Loss: 0.001524673 (0.001886792) 	 Physics Loss: 0.000554989 (0.001022693) 	 Data Loss: 0.000461125 (0.000466423) 	 BC Loss: 0.000508558 (0.000397676)
Iteration: [ 31001/ 50000] 	 Loss: 0.001906899 (0.002045441) 	 Physics Loss: 0.000959212 (0.001217105) 	 Data Loss: 0.000290614 (0.000451727) 	 BC Loss: 0.000657073 (0.000376609)
Iteration: [ 32001/ 50000] 	 Loss: 0.003575745 (0.002317760) 	 Physics Loss: 0.002933637 (0.001479890) 	 Data Loss: 0.000278991 (0.000500438) 	 BC Loss: 0.000363118 (0.000337431)
Iteration: [ 33001/ 50000] 	 Loss: 0.001054667 (0.001392871) 	 Physics Loss: 0.000542178 (0.000709339) 	 Data Loss: 0.000279479 (0.000395160) 	 BC Loss: 0.000233010 (0.000288372)
Iteration: [ 34001/ 50000] 	 Loss: 0.001258306 (0.001430126) 	 Physics Loss: 0.000604900 (0.000740890) 	 Data Loss: 0.000279625 (0.000439273) 	 BC Loss: 0.000373781 (0.000249964)
Iteration: [ 35001/ 50000] 	 Loss: 0.001458327 (0.001451151) 	 Physics Loss: 0.001126388 (0.000859504) 	 Data Loss: 0.000209119 (0.000364929) 	 BC Loss: 0.000122820 (0.000226719)
Iteration: [ 36001/ 50000] 	 Loss: 0.001965113 (0.001363893) 	 Physics Loss: 0.000819583 (0.000818373) 	 Data Loss: 0.000884497 (0.000343106) 	 BC Loss: 0.000261034 (0.000202413)
Iteration: [ 37001/ 50000] 	 Loss: 0.001001803 (0.001306025) 	 Physics Loss: 0.000458097 (0.000750440) 	 Data Loss: 0.000366133 (0.000375299) 	 BC Loss: 0.000177574 (0.000180287)
Iteration: [ 38001/ 50000] 	 Loss: 0.001594038 (0.001202732) 	 Physics Loss: 0.000834267 (0.000656006) 	 Data Loss: 0.000438084 (0.000355202) 	 BC Loss: 0.000321687 (0.000191524)
Iteration: [ 39001/ 50000] 	 Loss: 0.001096051 (0.001220055) 	 Physics Loss: 0.000612216 (0.000665410) 	 Data Loss: 0.000219414 (0.000392687) 	 BC Loss: 0.000264422 (0.000161957)
Iteration: [ 40001/ 50000] 	 Loss: 0.001038662 (0.001444635) 	 Physics Loss: 0.000485165 (0.000887432) 	 Data Loss: 0.000353080 (0.000370217) 	 BC Loss: 0.000200418 (0.000186986)
Iteration: [ 41001/ 50000] 	 Loss: 0.000912517 (0.001242174) 	 Physics Loss: 0.000559338 (0.000678301) 	 Data Loss: 0.000281135 (0.000411240) 	 BC Loss: 0.000072045 (0.000152633)
Iteration: [ 42001/ 50000] 	 Loss: 0.001224924 (0.001465745) 	 Physics Loss: 0.001037403 (0.000994991) 	 Data Loss: 0.000128774 (0.000311622) 	 BC Loss: 0.000058747 (0.000159133)
Iteration: [ 43001/ 50000] 	 Loss: 0.000964010 (0.001390206) 	 Physics Loss: 0.000636658 (0.000894072) 	 Data Loss: 0.000201709 (0.000336276) 	 BC Loss: 0.000125642 (0.000159858)
Iteration: [ 44001/ 50000] 	 Loss: 0.000622213 (0.001140883) 	 Physics Loss: 0.000339070 (0.000657654) 	 Data Loss: 0.000136122 (0.000339579) 	 BC Loss: 0.000147021 (0.000143649)
Iteration: [ 45001/ 50000] 	 Loss: 0.000991343 (0.001148664) 	 Physics Loss: 0.000418792 (0.000707322) 	 Data Loss: 0.000449356 (0.000315701) 	 BC Loss: 0.000123195 (0.000125641)
Iteration: [ 46001/ 50000] 	 Loss: 0.000918692 (0.000882140) 	 Physics Loss: 0.000450692 (0.000441957) 	 Data Loss: 0.000271820 (0.000319554) 	 BC Loss: 0.000196180 (0.000120629)
Iteration: [ 47001/ 50000] 	 Loss: 0.001575661 (0.001149394) 	 Physics Loss: 0.001348543 (0.000694363) 	 Data Loss: 0.000169103 (0.000319502) 	 BC Loss: 0.000058015 (0.000135529)
Iteration: [ 48001/ 50000] 	 Loss: 0.001518319 (0.001056366) 	 Physics Loss: 0.001149551 (0.000625095) 	 Data Loss: 0.000176694 (0.000317998) 	 BC Loss: 0.000192074 (0.000113272)
Iteration: [ 49001/ 50000] 	 Loss: 0.002202400 (0.001657575) 	 Physics Loss: 0.001789718 (0.001203460) 	 Data Loss: 0.000310050 (0.000312800) 	 BC Loss: 0.000102632 (0.000141316)

Visualizing the Results

julia
ts, xs, ys = 0.0f0:0.05f0:2.0f0, 0.0f0:0.02f0:2.0f0, 0.0f0:0.02f0:2.0f0
grid = stack([[elem...] for elem in vec(collect(Iterators.product(xs, ys, ts)))])

u_real = reshape(analytical_solution(grid), length(xs), length(ys), length(ts))

grid_normalized = (grid .- minimum(grid)) ./ (maximum(grid) .- minimum(grid))
u_pred = reshape(trained_u(grid_normalized), length(xs), length(ys), length(ts))
u_pred = u_pred .* (max_pde_val - min_pde_val) .+ min_pde_val

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in 1:length(ts)]
    Colorbar(fig[1, 2]; limits=extrema(stack(errs)))

    CairoMakie.record(fig, "pinn_nested_ad.gif", 1:length(ts); framerate=10) do i
        ax.title = "Abs. Predictor Error | Time: $(ts[i])"
        err = errs[i]
        contour!(ax, xs, ys, err; levels=10, linewidth=2)
        heatmap!(ax, xs, ys, err)
        return fig
    end

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.