Skip to content

Training a PINN on 2D PDE

In this tutorial we will go over using a PINN to solve 2D PDEs. We will be using the system from the NeuralPDE tutorials. However, we will write our own custom loss function and use the nested AD capabilities of Lux.jl.

This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the package: NeuralPDE.jl.

Package Imports

julia
using ADTypes, Lux, Optimisers, Zygote, Random, Printf, Statistics, MLUtils, OnlineStats,
      CairoMakie
using LuxCUDA

# Disable scalar indexing of CUDA arrays.
CUDA.allowscalar(false)

# Device handles used below: `gdev` moves parameters and data loaders to the GPU
# device, `cdev` moves the trained parameters and states back to the CPU.
const gdev = gpu_device()
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 5 methods)

Problem Definition

Since Lux supports efficient nested AD up to 2nd order, we will rewrite the problem in terms of first-order derivatives, so that we can compute the gradients of the loss using 2nd-order AD.

Define the Neural Networks

All the networks take 3 input variables and output a scalar value. Here, we will define a wrapper over the 3 networks, so that we can train them using Training.TrainState.

julia
# Container layer bundling the three sub-networks of the PINN. The fields `u`, `v`
# and `w` are registered as the container's child layers, so parameters and states
# are exposed as `ps.u`, `ps.v`, `ps.w` (and likewise for `st`).
struct PINN{NetU, NetV, NetW} <: Lux.AbstractLuxContainerLayer{(:u, :v, :w)}
    u::NetU
    v::NetV
    w::NetW
end

# Build a 4-layer MLP mapping 3 inputs to 1 output. The first three `Dense` layers
# use `act` as activation and have `hidden_dims` units; the output layer is linear.
function create_mlp(act, hidden_dims)
    input_layer = Dense(3 => hidden_dims, act)
    hidden_layer_1 = Dense(hidden_dims => hidden_dims, act)
    hidden_layer_2 = Dense(hidden_dims => hidden_dims, act)
    output_layer = Dense(hidden_dims => 1)
    return Chain(input_layer, hidden_layer_1, hidden_layer_2, output_layer)
end

# Keyword constructor: builds three independent `tanh` MLPs (one each for `u`, `v`
# and `w`), all with `hidden_dims` hidden units.
function PINN(; hidden_dims::Int=32)
    u_net, v_net, w_net = ntuple(_ -> create_mlp(tanh, hidden_dims), 3)
    return PINN(u_net, v_net, w_net)
end
Main.var"##225".PINN

Define the Loss Functions

We will define a custom loss function to compute the loss using 2nd order AD. We will use the following loss function

julia
# Physics (PDE residual) loss for the first-order reformulation of the problem.
#
# `u`, `v` and `w` are the three stateful networks and `xyt` holds one sample per
# column, with rows 1, 2 and 3 being x, y and t. The input gradients of each network
# are taken with `Zygote.gradient` of `sum ∘ net`, and the loss is the sum of the mean
# squared residuals of
#     ∂u/∂t - ∂v/∂x - ∂w/∂y = 0,    v - ∂u/∂x = 0,    w - ∂u/∂y = 0.
@views function physics_informed_loss_function(
        u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray)
    # `sum ∘ u` reduces the network output to a scalar so that its gradient w.r.t.
    # the inputs has the same shape as `xyt`.
    ∂u_∂xyt = only(Zygote.gradient(sum ∘ u, xyt))
    ∂u_∂x, ∂u_∂y, ∂u_∂t = ∂u_∂xyt[1:1, :], ∂u_∂xyt[2:2, :], ∂u_∂xyt[3:3, :]
    ∂v_∂x = only(Zygote.gradient(sum ∘ v, xyt))[1:1, :]
    v_xyt = v(xyt)
    ∂w_∂y = only(Zygote.gradient(sum ∘ w, xyt))[2:2, :]
    w_xyt = w(xyt)
    return (
        mean(abs2, ∂u_∂t .- ∂v_∂x .- ∂w_∂y) +
        mean(abs2, v_xyt .- ∂u_∂x) +
        mean(abs2, w_xyt .- ∂u_∂y)
    )
end
physics_informed_loss_function (generic function with 1 method)

Additionally, we need to compute the loss with respect to the boundary conditions.

julia
# Mean squared error between the prediction of network `u` at `xyt` and `target`.
function mse_loss_function(u::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray)
    prediction = u(xyt)
    criterion = MSELoss()
    return criterion(prediction, target)
end

# Total training objective: physics residual loss + data loss + boundary loss.
#
# `batch` is the tuple `(xyt, target_data, xyt_bc, target_bc)`. Returns a 3-tuple of
# the scalar loss, the updated states of the three sub-networks, and a named tuple
# with the individual loss terms.
function loss_function(model, ps, st, batch)
    xyt, target_data, xyt_bc, target_bc = batch

    # Wrap each sub-network together with its parameters and state.
    u_net = StatefulLuxLayer{true}(model.u, ps.u, st.u)
    v_net = StatefulLuxLayer{true}(model.v, ps.v, st.v)
    w_net = StatefulLuxLayer{true}(model.w, ps.w, st.w)

    physics_loss = physics_informed_loss_function(u_net, v_net, w_net, xyt)
    data_loss = mse_loss_function(u_net, target_data, xyt)
    bc_loss = mse_loss_function(u_net, target_bc, xyt_bc)

    loss = physics_loss + data_loss + bc_loss
    updated_states = (; u=u_net.st, v=v_net.st, w=w_net.st)
    loss_terms = (; physics_loss, data_loss, bc_loss)
    return (loss, updated_states, loss_terms)
end
loss_function (generic function with 1 method)

Generate the Data

We will generate some data to train the model on. We will sample a cubic spatio-temporal domain with $x \in [0, 2]$, $y \in [0, 2]$, and $t \in [0, 2]$. Typically, you want to be smarter about the sampling process, but for the sake of simplicity, we will skip that.

julia
# Closed-form reference solution, evaluated elementwise at (x, y, t).
analytical_solution(x, y, t) = exp.(x .+ y) .* cos.(x .+ y .+ 4 .* t)

# Convenience method for a 3×N matrix whose rows are x, y and t.
function analytical_solution(xyt)
    x, y, t = xyt[1, :], xyt[2, :], xyt[3, :]
    return analytical_solution(x, y, t)
end

begin
    # Interior samples: every point of a regular grid_len × grid_len × grid_len
    # lattice over [0, 2]^3, stored as one (x, y, t) column per point.
    grid_len = 16
    grid = range(0.0f0, 2.0f0; length=grid_len)
    xyt = stack([collect(pt) for pt in vec(collect(Iterators.product(grid, grid, grid)))])
    target_data = reshape(analytical_solution(xyt), 1, :)

    # Boundary samples: bc_len points on each of the faces t = 0, x = 0, x = 2,
    # y = 0 and y = 2.
    bc_len = 512
    x, y, t = ntuple(_ -> collect(range(0.0f0, 2.0f0; length=bc_len)), 3)

    xyt_bc = hcat(
        stack((x, y, fill(0.0f0, bc_len)); dims=1),
        stack((fill(0.0f0, bc_len), y, t); dims=1),
        stack((fill(2.0f0, bc_len), y, t); dims=1),
        stack((x, fill(0.0f0, bc_len), t); dims=1),
        stack((x, fill(2.0f0, bc_len), t); dims=1)
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    # Value range of the solution over both data sets; also used later to undo
    # the target normalisation.
    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val = min(min_data, min_target_bc)
    max_pde_val = max(max_data, max_target_bc)

    # Min-max normalise the inputs (each set by its own extrema) and the targets
    # (by the shared solution range).
    xyt = (xyt .- minimum(xyt)) ./ (maximum(xyt) - minimum(xyt))
    xyt_bc = (xyt_bc .- minimum(xyt_bc)) ./ (maximum(xyt_bc) - minimum(xyt_bc))
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end

Training

julia
# Train a `PINN` on the interior data (`xyt`, `target_data`) and the boundary data
# (`xyt_bc`, `target_bc`).
#
# Keyword arguments:
#   - `seed`:        seed applied to the default RNG before parameter initialisation.
#   - `maxiters`:    iteration count at which the training loop stops.
#   - `hidden_dims`: hidden width of each of the three MLPs.
#
# Returns a `StatefulLuxLayer` wrapping the model with its trained parameters and
# states moved back to the CPU. Throws an `ArgumentError` if the loss becomes NaN.
function train_model(xyt, target_data, xyt_bc, target_bc; seed::Int=0,
        maxiters::Int=50000, hidden_dims::Int=32)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> gdev

    bc_dataloader = DataLoader((xyt_bc, target_bc); batchsize=32, shuffle=true) |> gdev
    pde_dataloader = DataLoader((xyt, target_data); batchsize=32, shuffle=true) |> gdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.05f0))
    # Piecewise-constant learning-rate schedule.
    lr = i -> i < 5000 ? 0.05f0 : (i < 10000 ? 0.005f0 : 0.0005f0)

    # Buffers holding the 32 most recent loss values, used for the running means
    # printed in parentheses. `CircBuff(T, b; rev=true)` is the replacement for the
    # deprecated `Lag(T, b)`.
    total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple(
        _ -> OnlineStats.CircBuff(Float32, 32; rev=true), 4)

    iter = 1
    # Both loaders are cycled indefinitely; the loop is terminated by the iteration
    # counter below.
    for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in zip(
        Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader))
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoZygote(), loss_function, (
                xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state)

        # Abort before a NaN contaminates the running statistics.
        isnan(loss) && throw(ArgumentError("NaN Loss Detected"))

        fit!(total_loss_tracker, loss)
        fit!(physics_loss_tracker, stats.physics_loss)
        fit!(data_loss_tracker, stats.data_loss)
        fit!(bc_loss_tracker, stats.bc_loss)

        if iter % 500 == 1 || iter == maxiters
            # The running means are only needed for logging, so compute them here.
            mean_loss = mean(OnlineStats.value(total_loss_tracker))
            mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
            mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
            mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))

            @printf "Iteration: [%5d / %5d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \
                     (%.9f) \t Data Loss: %.9f (%.9f) \t BC \
                     Loss: %.9f (%.9f)\n" iter maxiters loss mean_loss stats.physics_loss mean_physics_loss stats.data_loss mean_data_loss stats.bc_loss mean_bc_loss
        end

        iter += 1
        # NOTE: the counter is incremented before this check, so the loop exits as
        # soon as `iter` reaches `maxiters` and no step is logged with
        # `iter == maxiters`.
        iter ≥ maxiters && break
    end

    return StatefulLuxLayer{true}(
        pinn, cdev(train_state.parameters), cdev(train_state.states))
end

# Train the full PINN, then wrap only the `u` sub-network (with its trained
# parameters and states) in test mode for inference.
trained_model = train_model(xyt, target_data, xyt_bc, target_bc)
trained_u = let m = trained_model
    Lux.testmode(StatefulLuxLayer{true}(m.model.u, m.ps.u, m.st.u))
end
┌ Warning: `Lag(T, b)` is deprecated.  Use `CircBuff(T,b,rev=true)` instead.
│   caller = #6 at 4_PINN2DPDE.md:16 [inlined]
└ @ Core /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/docs/src/tutorials/intermediate/4_PINN2DPDE.md:16
┌ Warning: `Lag(T, b)` is deprecated.  Use `CircBuff(T,b,rev=true)` instead.
│   caller = #6 at 4_PINN2DPDE.md:16 [inlined]
└ @ Core /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/docs/src/tutorials/intermediate/4_PINN2DPDE.md:16
┌ Warning: `Lag(T, b)` is deprecated.  Use `CircBuff(T,b,rev=true)` instead.
│   caller = #6 at 4_PINN2DPDE.md:16 [inlined]
└ @ Core /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/docs/src/tutorials/intermediate/4_PINN2DPDE.md:16
┌ Warning: `Lag(T, b)` is deprecated.  Use `CircBuff(T,b,rev=true)` instead.
│   caller = #6 at 4_PINN2DPDE.md:16 [inlined]
└ @ Core /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/docs/src/tutorials/intermediate/4_PINN2DPDE.md:16
Iteration: [    1 / 50000] 	 Loss: 3.159042358 (3.159042358) 	 Physics Loss: 1.982162476 (1.982162476) 	 Data Loss: 0.578374863 (0.578374863) 	 BC Loss: 0.598505080 (0.598505080)
Iteration: [  501 / 50000] 	 Loss: 0.040918160 (0.025583776) 	 Physics Loss: 0.000391877 (0.000269295) 	 Data Loss: 0.014243508 (0.009196416) 	 BC Loss: 0.026282774 (0.016118063)
Iteration: [ 1001 / 50000] 	 Loss: 0.015340659 (0.025281426) 	 Physics Loss: 0.000071670 (0.000163182) 	 Data Loss: 0.007905648 (0.010876314) 	 BC Loss: 0.007363341 (0.014241929)
Iteration: [ 1501 / 50000] 	 Loss: 0.019567011 (0.026170366) 	 Physics Loss: 0.001279382 (0.001009728) 	 Data Loss: 0.003071612 (0.010452257) 	 BC Loss: 0.015216017 (0.014708381)
Iteration: [ 2001 / 50000] 	 Loss: 0.035556547 (0.027273500) 	 Physics Loss: 0.004061943 (0.001871931) 	 Data Loss: 0.013011228 (0.010586374) 	 BC Loss: 0.018483378 (0.014815190)
Iteration: [ 2501 / 50000] 	 Loss: 0.011505678 (0.022150228) 	 Physics Loss: 0.002304791 (0.001940841) 	 Data Loss: 0.005615299 (0.007885863) 	 BC Loss: 0.003585588 (0.012323526)
Iteration: [ 3001 / 50000] 	 Loss: 0.031768262 (0.029164422) 	 Physics Loss: 0.008404830 (0.003920683) 	 Data Loss: 0.008808360 (0.010936455) 	 BC Loss: 0.014555071 (0.014307282)
Iteration: [ 3501 / 50000] 	 Loss: 0.017645847 (0.042499166) 	 Physics Loss: 0.001848093 (0.001690981) 	 Data Loss: 0.005461216 (0.018073166) 	 BC Loss: 0.010336538 (0.022735020)
Iteration: [ 4001 / 50000] 	 Loss: 0.028128654 (0.027620461) 	 Physics Loss: 0.005112350 (0.002448601) 	 Data Loss: 0.013499700 (0.011073254) 	 BC Loss: 0.009516605 (0.014098606)
Iteration: [ 4501 / 50000] 	 Loss: 0.014340003 (0.033320315) 	 Physics Loss: 0.001292084 (0.004329988) 	 Data Loss: 0.008556721 (0.012002973) 	 BC Loss: 0.004491198 (0.016987354)
Iteration: [ 5001 / 50000] 	 Loss: 0.030331207 (0.041541956) 	 Physics Loss: 0.000723805 (0.002695386) 	 Data Loss: 0.004466736 (0.016228525) 	 BC Loss: 0.025140665 (0.022618050)
Iteration: [ 5501 / 50000] 	 Loss: 0.022293953 (0.021166507) 	 Physics Loss: 0.000554349 (0.000707158) 	 Data Loss: 0.004542338 (0.007950513) 	 BC Loss: 0.017197266 (0.012508835)
Iteration: [ 6001 / 50000] 	 Loss: 0.018723227 (0.020389127) 	 Physics Loss: 0.000442930 (0.000975201) 	 Data Loss: 0.006906096 (0.007273634) 	 BC Loss: 0.011374202 (0.012140292)
Iteration: [ 6501 / 50000] 	 Loss: 0.028305896 (0.020493284) 	 Physics Loss: 0.000753467 (0.001849154) 	 Data Loss: 0.011062279 (0.008014920) 	 BC Loss: 0.016490150 (0.010629211)
Iteration: [ 7001 / 50000] 	 Loss: 0.015494239 (0.020624701) 	 Physics Loss: 0.007333768 (0.001517879) 	 Data Loss: 0.004043682 (0.008185850) 	 BC Loss: 0.004116789 (0.010920972)
Iteration: [ 7501 / 50000] 	 Loss: 0.019155467 (0.018008159) 	 Physics Loss: 0.002511769 (0.001448493) 	 Data Loss: 0.009514628 (0.006572613) 	 BC Loss: 0.007129070 (0.009987052)
Iteration: [ 8001 / 50000] 	 Loss: 0.018706074 (0.017215939) 	 Physics Loss: 0.001792779 (0.001894138) 	 Data Loss: 0.005628760 (0.005433898) 	 BC Loss: 0.011284535 (0.009887908)
Iteration: [ 8501 / 50000] 	 Loss: 0.017605182 (0.019859709) 	 Physics Loss: 0.002696904 (0.003125851) 	 Data Loss: 0.007372384 (0.005978104) 	 BC Loss: 0.007535893 (0.010755754)
Iteration: [ 9001 / 50000] 	 Loss: 0.017614847 (0.016386209) 	 Physics Loss: 0.003162773 (0.002853032) 	 Data Loss: 0.003553676 (0.004607514) 	 BC Loss: 0.010898398 (0.008925664)
Iteration: [ 9501 / 50000] 	 Loss: 0.008732248 (0.014530762) 	 Physics Loss: 0.001913586 (0.002123826) 	 Data Loss: 0.003329928 (0.004253434) 	 BC Loss: 0.003488734 (0.008153505)
Iteration: [10001 / 50000] 	 Loss: 0.017750096 (0.017799046) 	 Physics Loss: 0.003630795 (0.003061282) 	 Data Loss: 0.003031955 (0.005044522) 	 BC Loss: 0.011087346 (0.009693242)
Iteration: [10501 / 50000] 	 Loss: 0.007509941 (0.011342090) 	 Physics Loss: 0.001571818 (0.001135202) 	 Data Loss: 0.001611185 (0.002613829) 	 BC Loss: 0.004326937 (0.007593059)
Iteration: [11001 / 50000] 	 Loss: 0.022739200 (0.011915382) 	 Physics Loss: 0.000807529 (0.001103305) 	 Data Loss: 0.002606506 (0.003132161) 	 BC Loss: 0.019325165 (0.007679918)
Iteration: [11501 / 50000] 	 Loss: 0.019484218 (0.011457845) 	 Physics Loss: 0.001845157 (0.001325290) 	 Data Loss: 0.002358936 (0.002806861) 	 BC Loss: 0.015280124 (0.007325693)
Iteration: [12001 / 50000] 	 Loss: 0.019427970 (0.011598296) 	 Physics Loss: 0.004434146 (0.001386037) 	 Data Loss: 0.008772802 (0.003136263) 	 BC Loss: 0.006221022 (0.007075997)
Iteration: [12501 / 50000] 	 Loss: 0.012775686 (0.011906523) 	 Physics Loss: 0.001197650 (0.001402911) 	 Data Loss: 0.001059842 (0.002556076) 	 BC Loss: 0.010518193 (0.007947534)
Iteration: [13001 / 50000] 	 Loss: 0.006105572 (0.011037273) 	 Physics Loss: 0.000987124 (0.001454655) 	 Data Loss: 0.001305370 (0.002395016) 	 BC Loss: 0.003813077 (0.007187600)
Iteration: [13501 / 50000] 	 Loss: 0.010004668 (0.011247103) 	 Physics Loss: 0.001224264 (0.001615586) 	 Data Loss: 0.002474443 (0.002668936) 	 BC Loss: 0.006305961 (0.006962581)
Iteration: [14001 / 50000] 	 Loss: 0.009895653 (0.009313912) 	 Physics Loss: 0.001215764 (0.001694928) 	 Data Loss: 0.002037087 (0.002155375) 	 BC Loss: 0.006642802 (0.005463609)
Iteration: [14501 / 50000] 	 Loss: 0.014400685 (0.008724037) 	 Physics Loss: 0.001658659 (0.001564218) 	 Data Loss: 0.003658900 (0.002201537) 	 BC Loss: 0.009083126 (0.004958282)
Iteration: [15001 / 50000] 	 Loss: 0.007676640 (0.008405063) 	 Physics Loss: 0.001566568 (0.001608545) 	 Data Loss: 0.002262217 (0.001837625) 	 BC Loss: 0.003847855 (0.004958895)
Iteration: [15501 / 50000] 	 Loss: 0.004365115 (0.009256126) 	 Physics Loss: 0.001047856 (0.002247394) 	 Data Loss: 0.000648518 (0.002218451) 	 BC Loss: 0.002668741 (0.004790282)
Iteration: [16001 / 50000] 	 Loss: 0.004880759 (0.007501942) 	 Physics Loss: 0.001649067 (0.001671217) 	 Data Loss: 0.000296087 (0.001492234) 	 BC Loss: 0.002935605 (0.004338491)
Iteration: [16501 / 50000] 	 Loss: 0.008074892 (0.007220342) 	 Physics Loss: 0.001825325 (0.001970405) 	 Data Loss: 0.001302087 (0.001730468) 	 BC Loss: 0.004947479 (0.003519468)
Iteration: [17001 / 50000] 	 Loss: 0.005824474 (0.005835793) 	 Physics Loss: 0.002019164 (0.001728887) 	 Data Loss: 0.000624455 (0.001011056) 	 BC Loss: 0.003180855 (0.003095849)
Iteration: [17501 / 50000] 	 Loss: 0.006616294 (0.005807751) 	 Physics Loss: 0.002489866 (0.001884541) 	 Data Loss: 0.001154157 (0.001245214) 	 BC Loss: 0.002972272 (0.002677995)
Iteration: [18001 / 50000] 	 Loss: 0.004335414 (0.005200472) 	 Physics Loss: 0.001764537 (0.002015869) 	 Data Loss: 0.000705982 (0.001058994) 	 BC Loss: 0.001864895 (0.002125610)
Iteration: [18501 / 50000] 	 Loss: 0.004978007 (0.005351806) 	 Physics Loss: 0.002553018 (0.002129085) 	 Data Loss: 0.000752965 (0.001336156) 	 BC Loss: 0.001672024 (0.001886566)
Iteration: [19001 / 50000] 	 Loss: 0.004518208 (0.004657542) 	 Physics Loss: 0.001975382 (0.001959185) 	 Data Loss: 0.000239237 (0.000983244) 	 BC Loss: 0.002303589 (0.001715113)
Iteration: [19501 / 50000] 	 Loss: 0.006942283 (0.004119421) 	 Physics Loss: 0.004590358 (0.001714717) 	 Data Loss: 0.001148389 (0.000894285) 	 BC Loss: 0.001203537 (0.001510418)
Iteration: [20001 / 50000] 	 Loss: 0.008026988 (0.003463188) 	 Physics Loss: 0.005315070 (0.001485110) 	 Data Loss: 0.001919697 (0.000688489) 	 BC Loss: 0.000792221 (0.001289588)
Iteration: [20501 / 50000] 	 Loss: 0.004855067 (0.003504427) 	 Physics Loss: 0.001560361 (0.001747428) 	 Data Loss: 0.000460401 (0.000618138) 	 BC Loss: 0.002834306 (0.001138862)
Iteration: [21001 / 50000] 	 Loss: 0.002417301 (0.003299790) 	 Physics Loss: 0.000924424 (0.001513747) 	 Data Loss: 0.000300984 (0.000659882) 	 BC Loss: 0.001191893 (0.001126162)
Iteration: [21501 / 50000] 	 Loss: 0.003911218 (0.002670718) 	 Physics Loss: 0.002752088 (0.001218763) 	 Data Loss: 0.000835131 (0.000637475) 	 BC Loss: 0.000323999 (0.000814481)
Iteration: [22001 / 50000] 	 Loss: 0.002318957 (0.002511334) 	 Physics Loss: 0.001409405 (0.001309370) 	 Data Loss: 0.000344849 (0.000535839) 	 BC Loss: 0.000564703 (0.000666125)
Iteration: [22501 / 50000] 	 Loss: 0.001994215 (0.002343247) 	 Physics Loss: 0.000746255 (0.001113559) 	 Data Loss: 0.000431564 (0.000577084) 	 BC Loss: 0.000816396 (0.000652603)
Iteration: [23001 / 50000] 	 Loss: 0.002989073 (0.002301548) 	 Physics Loss: 0.001930058 (0.001243673) 	 Data Loss: 0.000805100 (0.000536315) 	 BC Loss: 0.000253915 (0.000521560)
Iteration: [23501 / 50000] 	 Loss: 0.002210422 (0.002371583) 	 Physics Loss: 0.001325711 (0.001179150) 	 Data Loss: 0.000160725 (0.000516864) 	 BC Loss: 0.000723986 (0.000675569)
Iteration: [24001 / 50000] 	 Loss: 0.002023818 (0.002276814) 	 Physics Loss: 0.000857553 (0.001135292) 	 Data Loss: 0.000525994 (0.000580056) 	 BC Loss: 0.000640270 (0.000561465)
Iteration: [24501 / 50000] 	 Loss: 0.002234935 (0.002236427) 	 Physics Loss: 0.001041825 (0.001241080) 	 Data Loss: 0.000286281 (0.000510494) 	 BC Loss: 0.000906830 (0.000484853)
Iteration: [25001 / 50000] 	 Loss: 0.001972227 (0.001982885) 	 Physics Loss: 0.000893347 (0.000971555) 	 Data Loss: 0.000666251 (0.000537486) 	 BC Loss: 0.000412629 (0.000473843)
Iteration: [25501 / 50000] 	 Loss: 0.001659834 (0.002081697) 	 Physics Loss: 0.001085884 (0.001258897) 	 Data Loss: 0.000359855 (0.000474958) 	 BC Loss: 0.000214095 (0.000347842)
Iteration: [26001 / 50000] 	 Loss: 0.002059042 (0.001770784) 	 Physics Loss: 0.001171167 (0.000890729) 	 Data Loss: 0.000426463 (0.000447986) 	 BC Loss: 0.000461412 (0.000432069)
Iteration: [26501 / 50000] 	 Loss: 0.002089650 (0.002068657) 	 Physics Loss: 0.001594031 (0.001356683) 	 Data Loss: 0.000330550 (0.000420685) 	 BC Loss: 0.000165069 (0.000291289)
Iteration: [27001 / 50000] 	 Loss: 0.001346401 (0.001879478) 	 Physics Loss: 0.000798861 (0.001050717) 	 Data Loss: 0.000333685 (0.000522715) 	 BC Loss: 0.000213855 (0.000306047)
Iteration: [27501 / 50000] 	 Loss: 0.001717428 (0.001494603) 	 Physics Loss: 0.000754971 (0.000796866) 	 Data Loss: 0.000408040 (0.000351864) 	 BC Loss: 0.000554417 (0.000345872)
Iteration: [28001 / 50000] 	 Loss: 0.001498525 (0.001789054) 	 Physics Loss: 0.000881248 (0.001008533) 	 Data Loss: 0.000310122 (0.000482373) 	 BC Loss: 0.000307154 (0.000298148)
Iteration: [28501 / 50000] 	 Loss: 0.001300987 (0.001541014) 	 Physics Loss: 0.000685782 (0.000873007) 	 Data Loss: 0.000372513 (0.000378884) 	 BC Loss: 0.000242692 (0.000289123)
Iteration: [29001 / 50000] 	 Loss: 0.001302390 (0.001531350) 	 Physics Loss: 0.000764329 (0.000809421) 	 Data Loss: 0.000326938 (0.000448536) 	 BC Loss: 0.000211123 (0.000273393)
Iteration: [29501 / 50000] 	 Loss: 0.001228882 (0.001639404) 	 Physics Loss: 0.000864926 (0.000892691) 	 Data Loss: 0.000257633 (0.000503808) 	 BC Loss: 0.000106323 (0.000242905)
Iteration: [30001 / 50000] 	 Loss: 0.001362466 (0.001636085) 	 Physics Loss: 0.000635906 (0.000994050) 	 Data Loss: 0.000649345 (0.000412151) 	 BC Loss: 0.000077215 (0.000229884)
Iteration: [30501 / 50000] 	 Loss: 0.000972152 (0.001479216) 	 Physics Loss: 0.000619220 (0.000883626) 	 Data Loss: 0.000230181 (0.000379169) 	 BC Loss: 0.000122752 (0.000216420)
Iteration: [31001 / 50000] 	 Loss: 0.005065940 (0.001380907) 	 Physics Loss: 0.004397592 (0.000831873) 	 Data Loss: 0.000477057 (0.000344165) 	 BC Loss: 0.000191290 (0.000204869)
Iteration: [31501 / 50000] 	 Loss: 0.001720798 (0.001439294) 	 Physics Loss: 0.001153235 (0.000778960) 	 Data Loss: 0.000366380 (0.000444564) 	 BC Loss: 0.000201183 (0.000215771)
Iteration: [32001 / 50000] 	 Loss: 0.001272172 (0.001182971) 	 Physics Loss: 0.000547293 (0.000645338) 	 Data Loss: 0.000612554 (0.000330688) 	 BC Loss: 0.000112325 (0.000206944)
Iteration: [32501 / 50000] 	 Loss: 0.001592739 (0.001303701) 	 Physics Loss: 0.001089040 (0.000732894) 	 Data Loss: 0.000368751 (0.000377450) 	 BC Loss: 0.000134948 (0.000193357)
Iteration: [33001 / 50000] 	 Loss: 0.001257007 (0.001731674) 	 Physics Loss: 0.000964463 (0.001143823) 	 Data Loss: 0.000217459 (0.000395994) 	 BC Loss: 0.000075086 (0.000191858)
Iteration: [33501 / 50000] 	 Loss: 0.001404967 (0.001307026) 	 Physics Loss: 0.000708074 (0.000685339) 	 Data Loss: 0.000530431 (0.000435325) 	 BC Loss: 0.000166462 (0.000186362)
Iteration: [34001 / 50000] 	 Loss: 0.000798481 (0.001060399) 	 Physics Loss: 0.000366831 (0.000566683) 	 Data Loss: 0.000218087 (0.000317301) 	 BC Loss: 0.000213564 (0.000176415)
Iteration: [34501 / 50000] 	 Loss: 0.001798573 (0.001364744) 	 Physics Loss: 0.001362987 (0.000847362) 	 Data Loss: 0.000308547 (0.000354412) 	 BC Loss: 0.000127040 (0.000162970)
Iteration: [35001 / 50000] 	 Loss: 0.000781053 (0.001014936) 	 Physics Loss: 0.000422994 (0.000539345) 	 Data Loss: 0.000270824 (0.000311768) 	 BC Loss: 0.000087234 (0.000163823)
Iteration: [35501 / 50000] 	 Loss: 0.001209691 (0.001329915) 	 Physics Loss: 0.000784697 (0.000694873) 	 Data Loss: 0.000264404 (0.000425117) 	 BC Loss: 0.000160591 (0.000209925)
Iteration: [36001 / 50000] 	 Loss: 0.001398915 (0.001136703) 	 Physics Loss: 0.000490567 (0.000583094) 	 Data Loss: 0.000744279 (0.000401171) 	 BC Loss: 0.000164069 (0.000152438)
Iteration: [36501 / 50000] 	 Loss: 0.000835373 (0.001367094) 	 Physics Loss: 0.000595709 (0.000792289) 	 Data Loss: 0.000132676 (0.000426714) 	 BC Loss: 0.000106987 (0.000148092)
Iteration: [37001 / 50000] 	 Loss: 0.000952036 (0.001044580) 	 Physics Loss: 0.000709375 (0.000580873) 	 Data Loss: 0.000161545 (0.000334467) 	 BC Loss: 0.000081115 (0.000129240)
Iteration: [37501 / 50000] 	 Loss: 0.000579479 (0.001085730) 	 Physics Loss: 0.000270049 (0.000599332) 	 Data Loss: 0.000242183 (0.000348117) 	 BC Loss: 0.000067247 (0.000138281)
Iteration: [38001 / 50000] 	 Loss: 0.001053359 (0.001110448) 	 Physics Loss: 0.000445312 (0.000595116) 	 Data Loss: 0.000482833 (0.000353724) 	 BC Loss: 0.000125214 (0.000161607)
Iteration: [38501 / 50000] 	 Loss: 0.000667115 (0.001049175) 	 Physics Loss: 0.000326662 (0.000543516) 	 Data Loss: 0.000173648 (0.000387762) 	 BC Loss: 0.000166804 (0.000117897)
Iteration: [39001 / 50000] 	 Loss: 0.000687853 (0.001085206) 	 Physics Loss: 0.000313576 (0.000590442) 	 Data Loss: 0.000321921 (0.000367264) 	 BC Loss: 0.000052356 (0.000127500)
Iteration: [39501 / 50000] 	 Loss: 0.000673423 (0.001160439) 	 Physics Loss: 0.000421422 (0.000714949) 	 Data Loss: 0.000192454 (0.000300986) 	 BC Loss: 0.000059547 (0.000144504)
Iteration: [40001 / 50000] 	 Loss: 0.000938448 (0.001243982) 	 Physics Loss: 0.000547043 (0.000710499) 	 Data Loss: 0.000287677 (0.000374305) 	 BC Loss: 0.000103729 (0.000159178)
Iteration: [40501 / 50000] 	 Loss: 0.001798752 (0.001093546) 	 Physics Loss: 0.000866074 (0.000650020) 	 Data Loss: 0.000780730 (0.000305183) 	 BC Loss: 0.000151947 (0.000138344)
Iteration: [41001 / 50000] 	 Loss: 0.001101400 (0.001448896) 	 Physics Loss: 0.000600817 (0.000901216) 	 Data Loss: 0.000405329 (0.000389668) 	 BC Loss: 0.000095254 (0.000158011)
Iteration: [41501 / 50000] 	 Loss: 0.000776893 (0.000824789) 	 Physics Loss: 0.000455096 (0.000410463) 	 Data Loss: 0.000206055 (0.000317277) 	 BC Loss: 0.000115742 (0.000097049)
Iteration: [42001 / 50000] 	 Loss: 0.001060384 (0.001240770) 	 Physics Loss: 0.000583868 (0.000778712) 	 Data Loss: 0.000378937 (0.000335865) 	 BC Loss: 0.000097578 (0.000126193)
Iteration: [42501 / 50000] 	 Loss: 0.000691815 (0.000944194) 	 Physics Loss: 0.000381213 (0.000516831) 	 Data Loss: 0.000162085 (0.000307154) 	 BC Loss: 0.000148517 (0.000120209)
Iteration: [43001 / 50000] 	 Loss: 0.000545245 (0.000868335) 	 Physics Loss: 0.000363892 (0.000449657) 	 Data Loss: 0.000124517 (0.000324281) 	 BC Loss: 0.000056836 (0.000094398)
Iteration: [43501 / 50000] 	 Loss: 0.001068081 (0.001141232) 	 Physics Loss: 0.000613936 (0.000622923) 	 Data Loss: 0.000369106 (0.000392804) 	 BC Loss: 0.000085040 (0.000125504)
Iteration: [44001 / 50000] 	 Loss: 0.001168622 (0.000963809) 	 Physics Loss: 0.000486189 (0.000540564) 	 Data Loss: 0.000549527 (0.000306839) 	 BC Loss: 0.000132906 (0.000116407)
Iteration: [44501 / 50000] 	 Loss: 0.000939374 (0.000756769) 	 Physics Loss: 0.000379668 (0.000386392) 	 Data Loss: 0.000454265 (0.000277635) 	 BC Loss: 0.000105440 (0.000092742)
Iteration: [45001 / 50000] 	 Loss: 0.001202180 (0.000892342) 	 Physics Loss: 0.000726597 (0.000506370) 	 Data Loss: 0.000367492 (0.000280093) 	 BC Loss: 0.000108091 (0.000105879)
Iteration: [45501 / 50000] 	 Loss: 0.000887047 (0.000908168) 	 Physics Loss: 0.000641108 (0.000560744) 	 Data Loss: 0.000148591 (0.000263541) 	 BC Loss: 0.000097348 (0.000083883)
Iteration: [46001 / 50000] 	 Loss: 0.000525872 (0.000734937) 	 Physics Loss: 0.000322635 (0.000362759) 	 Data Loss: 0.000158824 (0.000288360) 	 BC Loss: 0.000044412 (0.000083818)
Iteration: [46501 / 50000] 	 Loss: 0.000545968 (0.000907963) 	 Physics Loss: 0.000350560 (0.000489585) 	 Data Loss: 0.000136768 (0.000312205) 	 BC Loss: 0.000058640 (0.000106173)
Iteration: [47001 / 50000] 	 Loss: 0.000826687 (0.000924496) 	 Physics Loss: 0.000575209 (0.000511645) 	 Data Loss: 0.000205284 (0.000314716) 	 BC Loss: 0.000046194 (0.000098134)
Iteration: [47501 / 50000] 	 Loss: 0.000620959 (0.001185708) 	 Physics Loss: 0.000382126 (0.000788119) 	 Data Loss: 0.000148993 (0.000271556) 	 BC Loss: 0.000089840 (0.000126034)
Iteration: [48001 / 50000] 	 Loss: 0.000664698 (0.000798627) 	 Physics Loss: 0.000412207 (0.000461797) 	 Data Loss: 0.000182339 (0.000250679) 	 BC Loss: 0.000070152 (0.000086152)
Iteration: [48501 / 50000] 	 Loss: 0.000826713 (0.000925605) 	 Physics Loss: 0.000485339 (0.000512687) 	 Data Loss: 0.000170909 (0.000314013) 	 BC Loss: 0.000170466 (0.000098905)
Iteration: [49001 / 50000] 	 Loss: 0.000559556 (0.000846881) 	 Physics Loss: 0.000350503 (0.000421408) 	 Data Loss: 0.000165743 (0.000346362) 	 BC Loss: 0.000043311 (0.000079111)
Iteration: [49501 / 50000] 	 Loss: 0.000652327 (0.000740137) 	 Physics Loss: 0.000301691 (0.000376061) 	 Data Loss: 0.000195276 (0.000275985) 	 BC Loss: 0.000155360 (0.000088090)

Visualizing the Results

julia
# Evaluation lattice: a finer resolution in space than in time.
ts = 0.0f0:0.05f0:2.0f0
xs = 0.0f0:0.02f0:2.0f0
ys = 0.0f0:0.02f0:2.0f0
grid = stack([collect(pt) for pt in vec(collect(Iterators.product(xs, ys, ts)))])

# Reference solution on the lattice, arranged as (x, y, t).
u_real = reshape(analytical_solution(grid), length(xs), length(ys), length(ts))

# Min-max normalise the inputs, evaluate the network, and map its normalised
# output back to the original solution range.
grid_normalized = let lo = minimum(grid), span = maximum(grid) - minimum(grid)
    (grid .- lo) ./ span
end
u_pred = reshape(trained_u(grid_normalized), length(xs), length(ys), length(ts)) .*
         (max_pde_val - min_pde_val) .+ min_pde_val

begin
    # Animate the absolute prediction error over time and save it as a GIF.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    errs = [abs.(u_pred[:, :, i] .- u_real[:, :, i]) for i in eachindex(ts)]

    # One colour range shared by the colorbar and every frame, so the colours in
    # the heatmap actually correspond to the values shown on the colorbar.
    err_limits = extrema(stack(errs))
    Colorbar(fig[1, 2]; limits=err_limits)

    CairoMakie.record(fig, "pinn_nested_ad.gif", eachindex(ts); framerate=10) do i
        ax.title = "Abs. Predictor Error | Time: $(ts[i])"
        err = errs[i]
        # Remove the previous frame's plots instead of stacking new ones on top.
        empty!(ax)
        contour!(ax, xs, ys, err; levels=10, linewidth=2)
        heatmap!(ax, xs, ys, err; colorrange=err_limits)
        return fig
    end

    fig
end

Appendix

julia
# Print the Julia version and platform information of the build environment.
using InteractiveUtils
InteractiveUtils.versioninfo()

# Additionally print GPU backend details, but only for backends that are both
# loaded in this session and reported as functional by MLDataDevices.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.4, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.3.3
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

Preferences:
- CUDA_Driver_jll.compat: false

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.484 GiB / 4.750 GiB available)

This page was generated using Literate.jl.