Training a PINN on a 2D PDE

In this tutorial we will demonstrate how to use a PINN to solve a 2D PDE. We will use the system from the NeuralPDE Tutorials, but with a custom loss function and the nested AD capabilities of Lux.jl.

This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the dedicated package NeuralPDE.jl.

Package Imports

julia
using Lux, Optimisers, Random, Printf, Statistics, MLUtils, OnlineStats, CairoMakie,
    Reactant, Enzyme

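# Compile and run on the Reactant (XLA) device; `force = true` raises an error
# early if no functional Reactant backend is available. The CPU device is used
# later to move the trained parameters back for plotting.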
const xdev = reactant_device(; force = true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Problem Definition

Since Lux supports efficient nested AD up to 2nd order, we rewrite the problem in terms of first-order derivatives only, so that computing gradients of the loss requires at most 2nd-order AD.
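
Concretely, the residuals enforced in the physics loss below correspond to the second-order PDE

∂u/∂t = ∂²u/∂x² + ∂²u/∂y²

rewritten as a first-order system by introducing the auxiliary networks v and w:

v = ∂u/∂x,   w = ∂u/∂y,   ∂u/∂t = ∂v/∂x + ∂w/∂y

Each residual then involves only first derivatives of the networks, so differentiating the total loss during training needs just one more order of AD.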

Define the Neural Networks

All the networks take 3 input variables and output a scalar value. Here, we define a wrapper over the 3 networks so that we can train them jointly using Training.TrainState.

julia
struct PINN{U, V, W} <: Lux.AbstractLuxContainerLayer{(:u, :v, :w)}
    u::U
    v::V
    w::W
end

function create_mlp(act, hidden_dims)
    return Chain(
        Dense(3 => hidden_dims, act),
        Dense(hidden_dims => hidden_dims, act),
        Dense(hidden_dims => hidden_dims, act),
        Dense(hidden_dims => 1)
    )
end

function PINN(; hidden_dims::Int = 32)
    return PINN(
        create_mlp(tanh, hidden_dims),
        create_mlp(tanh, hidden_dims),
        create_mlp(tanh, hidden_dims)
    )
end
Main.var"##230".PINN
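
Since PINN is an AbstractLuxContainerLayer over the fields (:u, :v, :w), Lux.setup returns the parameters and states as NamedTuples keyed by those same field names. A minimal sketch (illustrative only, not part of the training script):

julia
pinn = PINN(; hidden_dims = 8)
ps, st = Lux.setup(Random.default_rng(), pinn)
keys(ps)  # (:u, :v, :w) — one parameter sub-tree per sub-network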

Define the Loss Functions

We will define a custom loss function to compute the loss using 2nd-order AD. We will use the following physics-informed loss function:

julia
@views function physics_informed_loss_function(
        u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray
    )
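    # A single reverse-mode pass through sum ∘ u yields ∂u/∂(x, y, t) for all
    # samples at once: each output column depends only on the matching input
    # column, so the gradient of the summed output recovers the per-sample
    # first derivatives. Rows 1, 2, 3 correspond to x, y, t.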
    ∂u_∂xyt = Enzyme.gradient(Enzyme.Reverse, sum ∘ u, xyt)[1]
    ∂u_∂x, ∂u_∂y, ∂u_∂t = ∂u_∂xyt[1:1, :], ∂u_∂xyt[2:2, :], ∂u_∂xyt[3:3, :]
    ∂v_∂x = Enzyme.gradient(Enzyme.Reverse, sum ∘ v, xyt)[1][1:1, :]
    v_xyt = v(xyt)
    ∂w_∂y = Enzyme.gradient(Enzyme.Reverse, sum ∘ w, xyt)[1][2:2, :]
    w_xyt = w(xyt)
    return (
        mean(abs2, ∂u_∂t .- ∂v_∂x .- ∂w_∂y) +
            mean(abs2, v_xyt .- ∂u_∂x) +
            mean(abs2, w_xyt .- ∂u_∂y)
    )
end
physics_informed_loss_function (generic function with 1 method)

Additionally, we need to compute the loss with respect to the boundary conditions and the data.

julia
function mse_loss_function(u::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray)
    return MSELoss()(u(xyt), target)
end

function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
    u_net = StatefulLuxLayer{true}(model.u, ps.u, st.u)
    v_net = StatefulLuxLayer{true}(model.v, ps.v, st.v)
    w_net = StatefulLuxLayer{true}(model.w, ps.w, st.w)
    physics_loss = physics_informed_loss_function(u_net, v_net, w_net, xyt)
    data_loss = mse_loss_function(u_net, target_data, xyt)
    bc_loss = mse_loss_function(u_net, target_bc, xyt_bc)
    loss = physics_loss + data_loss + bc_loss
    return (
        loss,
        (; u = u_net.st, v = v_net.st, w = w_net.st),
        (; physics_loss, data_loss, bc_loss),
    )
end
loss_function (generic function with 1 method)
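
Before compiling the full training loop it can be helpful to sanity-check the loss on a tiny random batch. A hypothetical CPU-only smoke test (not part of the original tutorial):

julia
let rng = Random.default_rng()
    pinn = PINN(; hidden_dims = 8)
    ps, st = Lux.setup(rng, pinn)
    xyt_s, xyt_bc_s = rand(rng, Float32, 3, 4), rand(rng, Float32, 3, 4)
    target_s, target_bc_s = rand(rng, Float32, 1, 4), rand(rng, Float32, 1, 4)
    l, _, stats = loss_function(pinn, ps, st, (xyt_s, target_s, xyt_bc_s, target_bc_s))
    @assert isfinite(l)  # shapes line up and the loss is finite
end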

Generate the Data

We will generate training data on the square space-time domain x ∈ [0, 2], y ∈ [0, 2], and t ∈ [0, 2], sampled on a uniform grid. Typically, you would want to be smarter about the sampling process, but for the sake of simplicity we will skip that.

julia
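# Closed-form solution u(x, y, t) = exp(x + y) * cos(x + y + 4t); used both to
# generate the training data and to evaluate the trained model.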
analytical_solution(x, y, t) = @. exp(x + y) * cos(x + y + 4t)
analytical_solution(xyt) = analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :])

begin
    grid_len = 16

    grid = range(0.0f0, 2.0f0; length = grid_len)
    xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))])

    target_data = reshape(analytical_solution(xyt), 1, :)

    bc_len = 512

    x = collect(range(0.0f0, 2.0f0; length = bc_len))
    y = collect(range(0.0f0, 2.0f0; length = bc_len))
    t = collect(range(0.0f0, 2.0f0; length = bc_len))

    xyt_bc = hcat(
        stack((x, y, zeros(Float32, bc_len)); dims = 1),
        stack((zeros(Float32, bc_len), y, t); dims = 1),
        stack((ones(Float32, bc_len) .* 2, y, t); dims = 1),
        stack((x, zeros(Float32, bc_len), t); dims = 1),
        stack((x, ones(Float32, bc_len) .* 2, t); dims = 1)
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc)

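    # Min-max scale the inputs to [0, 1] and the targets by their global
    # extrema; min_pde_val / max_pde_val are kept to undo the scaling later.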
    xyt = (xyt .- minimum(xyt)) ./ (maximum(xyt) .- minimum(xyt))
    xyt_bc = (xyt_bc .- minimum(xyt_bc)) ./ (maximum(xyt_bc) .- minimum(xyt_bc))
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end

Training

julia
function train_model(
        xyt, target_data, xyt_bc, target_bc; seed::Int = 0,
        maxiters::Int = 50000, hidden_dims::Int = 32
    )
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> xdev

    bc_dataloader = DataLoader(
        (xyt_bc, target_bc); batchsize = 32, shuffle = true, partial = false
    ) |> xdev
    pde_dataloader = DataLoader(
        (xyt, target_data); batchsize = 32, shuffle = true, partial = false
    ) |> xdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.05f0))
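    # Piecewise-constant learning-rate schedule: 0.05 for the first 5k steps,
    # 0.005 until 10k steps, then 0.0005 for the remainder of training.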
    lr = i -> i < 5000 ? 0.05f0 : (i < 10000 ? 0.005f0 : 0.0005f0)

    total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
        _ -> OnlineStats.CircBuff(Float32, 32; rev = true), 4
    )

    iter = 1
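    # Cycle both dataloaders indefinitely and pair one PDE batch with one BC
    # batch per step; the loop exits via the iteration check at the bottom.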
    for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in zip(
            Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader)
        )
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoEnzyme(), loss_function,
            (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state
        )

        fit!(total_loss_tracker, Float32(loss))
        fit!(physics_loss_tracker, Float32(stats.physics_loss))
        fit!(data_loss_tracker, Float32(stats.data_loss))
        fit!(bc_loss_tracker, Float32(stats.bc_loss))

        mean_loss = mean(OnlineStats.value(total_loss_tracker))
        mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
        mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
        mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))

        isnan(loss) && throw(ArgumentError("NaN Loss Detected"))

        if iter % 1000 == 1 || iter == maxiters
            @printf "Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
                     (%.9f) \t Data Loss: %.9f (%.9f) \t BC \
                     Loss: %.9f (%.9f)\n" iter maxiters loss mean_loss stats.physics_loss mean_physics_loss stats.data_loss mean_data_loss stats.bc_loss mean_bc_loss
        end

        iter += 1
        iter ≥ maxiters && break
    end

    return StatefulLuxLayer{true}(
        pinn, cdev(train_state.parameters), cdev(train_state.states)
    )
end

trained_model = train_model(xyt, target_data, xyt_bc, target_bc)
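# Extract just the u sub-network for inference; testmode switches any
# training-specific layer behavior to evaluation mode (none of these layers
# have any, but it is good practice).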
trained_u = Lux.testmode(
    StatefulLuxLayer{true}(trained_model.model.u, trained_model.ps.u, trained_model.st.u)
)
Iteration: [     1/ 50000] 	 Loss: 3.158992767 (3.158992767) 	 Physics Loss: 1.982357264 (1.982357264) 	 Data Loss: 0.578243732 (0.578243732) 	 BC Loss: 0.598391712 (0.598391712)
Iteration: [  1001/ 50000] 	 Loss: 0.034934171 (0.028678153) 	 Physics Loss: 0.000550666 (0.000427971) 	 Data Loss: 0.022407684 (0.011849238) 	 BC Loss: 0.011975820 (0.016400939)
Iteration: [  2001/ 50000] 	 Loss: 0.029342793 (0.031756539) 	 Physics Loss: 0.001803906 (0.001986223) 	 Data Loss: 0.013016323 (0.012399685) 	 BC Loss: 0.014522564 (0.017370628)
Iteration: [  3001/ 50000] 	 Loss: 0.021532167 (0.027456025) 	 Physics Loss: 0.002033974 (0.004514552) 	 Data Loss: 0.004263131 (0.008210877) 	 BC Loss: 0.015235063 (0.014730594)
Iteration: [  4001/ 50000] 	 Loss: 0.049741872 (0.035063371) 	 Physics Loss: 0.008673832 (0.002975470) 	 Data Loss: 0.016625574 (0.012557827) 	 BC Loss: 0.024442466 (0.019530077)
Iteration: [  5001/ 50000] 	 Loss: 0.021555956 (0.037785888) 	 Physics Loss: 0.002584497 (0.009353531) 	 Data Loss: 0.008988982 (0.011811271) 	 BC Loss: 0.009982477 (0.016621085)
Iteration: [  6001/ 50000] 	 Loss: 0.022884602 (0.019320106) 	 Physics Loss: 0.000503887 (0.000844118) 	 Data Loss: 0.005671175 (0.006657102) 	 BC Loss: 0.016709540 (0.011818888)
Iteration: [  7001/ 50000] 	 Loss: 0.017019382 (0.019912010) 	 Physics Loss: 0.000786213 (0.000926178) 	 Data Loss: 0.006546173 (0.007675517) 	 BC Loss: 0.009686996 (0.011310317)
Iteration: [  8001/ 50000] 	 Loss: 0.022041641 (0.018215429) 	 Physics Loss: 0.002748450 (0.002823828) 	 Data Loss: 0.004308246 (0.004818310) 	 BC Loss: 0.014984946 (0.010573289)
Iteration: [  9001/ 50000] 	 Loss: 0.020825051 (0.016353965) 	 Physics Loss: 0.001645100 (0.001697728) 	 Data Loss: 0.003023159 (0.004666437) 	 BC Loss: 0.016156793 (0.009989802)
Iteration: [ 10001/ 50000] 	 Loss: 0.012419771 (0.014550619) 	 Physics Loss: 0.001211297 (0.002143627) 	 Data Loss: 0.006038338 (0.003847382) 	 BC Loss: 0.005170135 (0.008559611)
Iteration: [ 11001/ 50000] 	 Loss: 0.016400268 (0.013275324) 	 Physics Loss: 0.001640767 (0.001365761) 	 Data Loss: 0.003816272 (0.003336374) 	 BC Loss: 0.010943229 (0.008573188)
Iteration: [ 12001/ 50000] 	 Loss: 0.012625601 (0.012438955) 	 Physics Loss: 0.001496360 (0.001293942) 	 Data Loss: 0.005017307 (0.003183292) 	 BC Loss: 0.006111934 (0.007961722)
Iteration: [ 13001/ 50000] 	 Loss: 0.006276353 (0.011658882) 	 Physics Loss: 0.000723396 (0.001394610) 	 Data Loss: 0.001822183 (0.002704730) 	 BC Loss: 0.003730775 (0.007559542)
Iteration: [ 14001/ 50000] 	 Loss: 0.007330933 (0.010494760) 	 Physics Loss: 0.001231446 (0.001130134) 	 Data Loss: 0.003328452 (0.002512253) 	 BC Loss: 0.002771034 (0.006852372)
Iteration: [ 15001/ 50000] 	 Loss: 0.012322948 (0.010581179) 	 Physics Loss: 0.001112666 (0.001342141) 	 Data Loss: 0.000318603 (0.002849674) 	 BC Loss: 0.010891679 (0.006389363)
Iteration: [ 16001/ 50000] 	 Loss: 0.005651589 (0.011476245) 	 Physics Loss: 0.001261572 (0.001271797) 	 Data Loss: 0.001857541 (0.002460307) 	 BC Loss: 0.002532476 (0.007744140)
Iteration: [ 17001/ 50000] 	 Loss: 0.004589280 (0.010282698) 	 Physics Loss: 0.000900613 (0.001197141) 	 Data Loss: 0.000789952 (0.002489067) 	 BC Loss: 0.002898715 (0.006596491)
Iteration: [ 18001/ 50000] 	 Loss: 0.013593317 (0.008794412) 	 Physics Loss: 0.001301785 (0.001351097) 	 Data Loss: 0.003135182 (0.001751710) 	 BC Loss: 0.009156350 (0.005691605)
Iteration: [ 19001/ 50000] 	 Loss: 0.009044455 (0.009721650) 	 Physics Loss: 0.001797066 (0.001625932) 	 Data Loss: 0.003098861 (0.001854344) 	 BC Loss: 0.004148528 (0.006241375)
Iteration: [ 20001/ 50000] 	 Loss: 0.004531536 (0.007789471) 	 Physics Loss: 0.002314395 (0.001649175) 	 Data Loss: 0.001491809 (0.001308756) 	 BC Loss: 0.000725332 (0.004831538)
Iteration: [ 21001/ 50000] 	 Loss: 0.004263383 (0.006220151) 	 Physics Loss: 0.001818714 (0.001836546) 	 Data Loss: 0.001014963 (0.001413772) 	 BC Loss: 0.001429706 (0.002969833)
Iteration: [ 22001/ 50000] 	 Loss: 0.006228818 (0.005886041) 	 Physics Loss: 0.003274539 (0.002052142) 	 Data Loss: 0.002419390 (0.000916688) 	 BC Loss: 0.000534889 (0.002917211)
Iteration: [ 23001/ 50000] 	 Loss: 0.004727511 (0.004308250) 	 Physics Loss: 0.001700793 (0.001910425) 	 Data Loss: 0.001145163 (0.000927515) 	 BC Loss: 0.001881554 (0.001470311)
Iteration: [ 24001/ 50000] 	 Loss: 0.005070512 (0.003963022) 	 Physics Loss: 0.001458622 (0.001738792) 	 Data Loss: 0.000377079 (0.000693607) 	 BC Loss: 0.003234812 (0.001530622)
Iteration: [ 25001/ 50000] 	 Loss: 0.002241082 (0.003130429) 	 Physics Loss: 0.000983181 (0.001482069) 	 Data Loss: 0.000401148 (0.000641221) 	 BC Loss: 0.000856754 (0.001007139)
Iteration: [ 26001/ 50000] 	 Loss: 0.002955514 (0.002503448) 	 Physics Loss: 0.001489272 (0.001270391) 	 Data Loss: 0.000252445 (0.000556399) 	 BC Loss: 0.001213797 (0.000676658)
Iteration: [ 27001/ 50000] 	 Loss: 0.003096203 (0.002668173) 	 Physics Loss: 0.001007635 (0.001394353) 	 Data Loss: 0.000529087 (0.000551110) 	 BC Loss: 0.001559482 (0.000722710)
Iteration: [ 28001/ 50000] 	 Loss: 0.001113268 (0.002219696) 	 Physics Loss: 0.000669667 (0.001248004) 	 Data Loss: 0.000210634 (0.000505303) 	 BC Loss: 0.000232967 (0.000466388)
Iteration: [ 29001/ 50000] 	 Loss: 0.001584558 (0.001840178) 	 Physics Loss: 0.000857728 (0.000936380) 	 Data Loss: 0.000405950 (0.000491625) 	 BC Loss: 0.000320880 (0.000412173)
Iteration: [ 30001/ 50000] 	 Loss: 0.001524673 (0.001886792) 	 Physics Loss: 0.000554989 (0.001022693) 	 Data Loss: 0.000461125 (0.000466423) 	 BC Loss: 0.000508558 (0.000397676)
Iteration: [ 31001/ 50000] 	 Loss: 0.001906899 (0.002045441) 	 Physics Loss: 0.000959212 (0.001217105) 	 Data Loss: 0.000290614 (0.000451727) 	 BC Loss: 0.000657073 (0.000376609)
Iteration: [ 32001/ 50000] 	 Loss: 0.003575745 (0.002317760) 	 Physics Loss: 0.002933637 (0.001479890) 	 Data Loss: 0.000278991 (0.000500438) 	 BC Loss: 0.000363118 (0.000337431)
Iteration: [ 33001/ 50000] 	 Loss: 0.001054667 (0.001392871) 	 Physics Loss: 0.000542178 (0.000709339) 	 Data Loss: 0.000279479 (0.000395160) 	 BC Loss: 0.000233010 (0.000288372)
Iteration: [ 34001/ 50000] 	 Loss: 0.001258306 (0.001430126) 	 Physics Loss: 0.000604900 (0.000740890) 	 Data Loss: 0.000279625 (0.000439273) 	 BC Loss: 0.000373781 (0.000249964)
Iteration: [ 35001/ 50000] 	 Loss: 0.001458327 (0.001451151) 	 Physics Loss: 0.001126388 (0.000859504) 	 Data Loss: 0.000209119 (0.000364929) 	 BC Loss: 0.000122820 (0.000226719)
Iteration: [ 36001/ 50000] 	 Loss: 0.001965113 (0.001363893) 	 Physics Loss: 0.000819583 (0.000818373) 	 Data Loss: 0.000884497 (0.000343106) 	 BC Loss: 0.000261034 (0.000202413)
Iteration: [ 37001/ 50000] 	 Loss: 0.001001803 (0.001306025) 	 Physics Loss: 0.000458097 (0.000750440) 	 Data Loss: 0.000366133 (0.000375299) 	 BC Loss: 0.000177574 (0.000180287)
Iteration: [ 38001/ 50000] 	 Loss: 0.001594038 (0.001202732) 	 Physics Loss: 0.000834267 (0.000656006) 	 Data Loss: 0.000438084 (0.000355202) 	 BC Loss: 0.000321687 (0.000191524)
Iteration: [ 39001/ 50000] 	 Loss: 0.001096051 (0.001220055) 	 Physics Loss: 0.000612216 (0.000665410) 	 Data Loss: 0.000219414 (0.000392687) 	 BC Loss: 0.000264422 (0.000161957)
Iteration: [ 40001/ 50000] 	 Loss: 0.001038662 (0.001444635) 	 Physics Loss: 0.000485165 (0.000887432) 	 Data Loss: 0.000353080 (0.000370217) 	 BC Loss: 0.000200418 (0.000186986)
Iteration: [ 41001/ 50000] 	 Loss: 0.000912517 (0.001242174) 	 Physics Loss: 0.000559338 (0.000678301) 	 Data Loss: 0.000281135 (0.000411240) 	 BC Loss: 0.000072045 (0.000152633)
Iteration: [ 42001/ 50000] 	 Loss: 0.001224924 (0.001465745) 	 Physics Loss: 0.001037403 (0.000994991) 	 Data Loss: 0.000128774 (0.000311622) 	 BC Loss: 0.000058747 (0.000159133)
Iteration: [ 43001/ 50000] 	 Loss: 0.000964010 (0.001390206) 	 Physics Loss: 0.000636658 (0.000894072) 	 Data Loss: 0.000201709 (0.000336276) 	 BC Loss: 0.000125642 (0.000159858)
Iteration: [ 44001/ 50000] 	 Loss: 0.000622213 (0.001140883) 	 Physics Loss: 0.000339070 (0.000657654) 	 Data Loss: 0.000136122 (0.000339579) 	 BC Loss: 0.000147021 (0.000143649)
Iteration: [ 45001/ 50000] 	 Loss: 0.000991343 (0.001148664) 	 Physics Loss: 0.000418792 (0.000707322) 	 Data Loss: 0.000449356 (0.000315701) 	 BC Loss: 0.000123195 (0.000125641)
Iteration: [ 46001/ 50000] 	 Loss: 0.000918692 (0.000882140) 	 Physics Loss: 0.000450692 (0.000441957) 	 Data Loss: 0.000271820 (0.000319554) 	 BC Loss: 0.000196180 (0.000120629)
Iteration: [ 47001/ 50000] 	 Loss: 0.001575661 (0.001149394) 	 Physics Loss: 0.001348543 (0.000694363) 	 Data Loss: 0.000169103 (0.000319502) 	 BC Loss: 0.000058015 (0.000135529)
Iteration: [ 48001/ 50000] 	 Loss: 0.001518319 (0.001056366) 	 Physics Loss: 0.001149551 (0.000625095) 	 Data Loss: 0.000176694 (0.000317998) 	 BC Loss: 0.000192074 (0.000113272)
Iteration: [ 49001/ 50000] 	 Loss: 0.002202400 (0.001657575) 	 Physics Loss: 0.001789718 (0.001203460) 	 Data Loss: 0.000310050 (0.000312800) 	 BC Loss: 0.000102632 (0.000141316)

Visualizing the Results

julia
ts, xs, ys = 0.0f0:0.05f0:2.0f0, 0.0f0:0.02f0:2.0f0, 0.0f0:0.02f0:2.0f0
grid = stack([[elem...] for elem in vec(collect(Iterators.product(xs, ys, ts)))])

u_real = reshape(analytical_solution(grid), length(xs), length(ys), length(ts))

grid_normalized = (grid .- minimum(grid)) ./ (maximum(grid) .- minimum(grid))
u_pred = reshape(trained_u(grid_normalized), length(xs), length(ys), length(ts))
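# Undo the target min-max scaling so the prediction is on the same scale as u_real.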
u_pred = u_pred .* (max_pde_val - min_pde_val) .+ min_pde_val

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")
    errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in 1:length(ts)]
    Colorbar(fig[1, 2]; limits = extrema(stack(errs)))

    CairoMakie.record(fig, "pinn_nested_ad.gif", 1:length(ts); framerate = 10) do i
        ax.title = "Abs. Predictor Error | Time: $(ts[i])"
        err = errs[i]
        contour!(ax, xs, ys, err; levels = 10, linewidth = 2)
        heatmap!(ax, xs, ys, err)
        return fig
    end

    fig
end
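
As a single quantitative summary (an optional addition, not part of the original tutorial), the relative L2 error of the prediction over the full space-time grid can be computed directly from the arrays above:

julia
rel_l2 = sqrt(sum(abs2, u_pred .- u_real) / sum(abs2, u_real))
@printf "Relative L2 error: %.4f\n" rel_l2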

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.