MNIST Classification using Neural ODEs
To understand Neural ODEs, users should look up these lecture notes. We recommend that users use DiffEqFlux.jl directly, instead of implementing Neural ODEs from scratch.
Package Imports
using Lux, ComponentArrays, SciMLSensitivity, AMDGPU, LuxCUDA, Optimisers, OrdinaryDiffEq,
Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf
import MLDatasets: MNIST
import MLUtils: DataLoader, splitobs
CUDA.allowscalar(false)
Loading MNIST
function loadmnist(batchsize, train_split)
# Load MNIST: Only 1500 for demonstration purposes
N = 1500
dataset = MNIST(; split=:train)
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
# Process images into (H,W,C,BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
y_data = onehotbatch(labels_raw, 0:9)
(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
# Don't shuffle the test data
DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
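As a quick sanity check (a minimal usage sketch; the train function below calls loadmnist internally), we can inspect the shapes the dataloaders produce:
train_dataloader, test_dataloader = loadmnist(128, 0.9)
x, y = first(train_dataloader)
size(x)  # (28, 28, 1, 128) -- (H, W, C, batch)
size(y)  # (10, 128) -- onehot targets over the digits 0:9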
Define the Neural ODE Layer
First we will use the @compact macro to define the Neural ODE layer.
function NeuralODECompact(
model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
return @compact(; model, solver, tspan, kwargs...) do x, p
dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
# Note the `p.model` here
prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
@return solve(prob, solver; kwargs...)
end
end
NeuralODECompact (generic function with 1 method)
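As a hedged usage sketch (the dummy layer and sizes here are chosen purely for illustration), the compact layer is set up and called like any other Lux layer, and the forward pass returns the ODESolution:
node = NeuralODECompact(Dense(4 => 4, tanh); save_everystep=false)
ps, st = Lux.setup(Random.default_rng(), node)
x0 = rand(Float32, 4, 2)   # (features, batch)
sol, _ = node(x0, ps, st)  # `sol` is an ODESolution; the final state is last(sol.u)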
We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from when @compact was not part of the stable API. It also helps users understand how the Lux layer interface works.
The NeuralODE is a ContainerLayer which stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <:
Lux.AbstractExplicitContainerLayer{(:model,)}
model::M
solver::So
tspan::T
kwargs::K
end
function NeuralODE(
model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##225".NeuralODE
OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.
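This round trip is lossless, since vec and reshape only change the shape metadata:
x = rand(Float32, 28, 28, 1, 3)
u = vec(x)                # flatten for the ODE solver
reshape(u, size(x)) == x  # true -- the shape is restored inside `dudt`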
function (n::NeuralODE)(x, ps, st)
function dudt(u, p, t)
u_, st = n.model(reshape(u, size(x)), p, st)
return vec(u_)
end
prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
return solve(prob, n.solver; n.kwargs...), st
end
@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
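A small sketch of the matrix method (the Int argument is the leading dimension to restore; Base.Fix1 partially applies it inside create_model below):
U = rand(Float32, 20 * 3, 5)           # vec'd (20, 3) states saved at 5 time points
size(diffeqsol_to_array(20, U))        # (20, 3) -- the last column, reshaped
f = Base.Fix1(diffeqsol_to_array, 20)  # equivalent to u -> diffeqsol_to_array(20, u)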
Create and Initialize the Neural ODE Layer
function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Bool=false,
sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()))
# Construct the Neural ODE Model
model = Chain(FlattenLayer(),
Dense(784 => 20, tanh),
model_fn(Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
save_everystep=false, reltol=1.0f-3,
abstol=1.0f-3, save_start=false, sensealg),
Base.Fix1(diffeqsol_to_array, 20),
Dense(20 => 10))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
st = st |> dev
return model, ps, st
end
create_model (generic function with 2 methods)
Define Utility Functions
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
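Since logits=Val(true), the loss consumes raw network outputs directly, so no softmax layer is needed. A minimal sketch of the call signature (Lux loss objects are callable on predictions and targets):
y_pred = randn(Float32, 10, 4)      # raw logits for 4 samples
y = onehotbatch([0, 3, 7, 9], 0:9)  # onehot targets
logitcrossentropy(y_pred, y)        # mean cross-entropy, a scalar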
function accuracy(model, ps, st, dataloader; dev=gpu_device())
total_correct, total = 0, 0
st = Lux.testmode(st)
cpu_dev = cpu_device()
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
accuracy (generic function with 1 method)
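Note that onecold is applied to both the onehot targets and the raw model outputs; in either case it returns the 1-based argmax position along the first dimension, so the two are directly comparable. A small sketch:
y = onehotbatch([2, 5], 0:9)    # onehot columns for the digits 2 and 5
onecold(y)                      # [3, 6] -- 1-based positions within 0:9
onecold(randn(Float32, 10, 2))  # argmax along the first dimension, same convention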
Training
function train(model_function; cpu::Bool=false, kwargs...)
dev = cpu ? cpu_device() : gpu_device()
model, ps, st = create_model(model_function; dev, kwargs...)
# Training
train_dataloader, test_dataloader = loadmnist(128, 0.9)
tstate = Training.TrainState(model, ps, st, Adam(0.001f0))
### Let's train the model
nepochs = 9
for epoch in 1:nepochs
stime = time()
for (x, y) in train_dataloader
x = dev(x)
y = dev(y)
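# `single_train_step!` computes the gradients with the chosen AD backend,
# updates the parameters and optimizer state stored in `tstate`, and
# returns (grads, loss, stats, tstate); we only keep the updated state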
_, _, _, tstate = Training.single_train_step!(
AutoZygote(), logitcrossentropy, (x, y), tstate)
end
ttime = time() - stime
tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader; dev)
te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader; dev)
@printf "[%d/%d] \t Time %.2fs \t Training Accuracy: %.5f%% \t Test \
Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
end
end
train(NeuralODECompact)
[1/9] Time 127.45s Training Accuracy: 0.50963% Test Accuracy: 0.43333%
[2/9] Time 0.39s Training Accuracy: 0.71333% Test Accuracy: 0.66000%
[3/9] Time 0.41s Training Accuracy: 0.77778% Test Accuracy: 0.70000%
[4/9] Time 0.39s Training Accuracy: 0.81111% Test Accuracy: 0.73333%
[5/9] Time 0.40s Training Accuracy: 0.82444% Test Accuracy: 0.77333%
[6/9] Time 0.41s Training Accuracy: 0.84519% Test Accuracy: 0.79333%
[7/9] Time 0.64s Training Accuracy: 0.85407% Test Accuracy: 0.80000%
[8/9] Time 0.42s Training Accuracy: 0.86815% Test Accuracy: 0.81333%
[9/9] Time 0.42s Training Accuracy: 0.87926% Test Accuracy: 0.82667%
train(NeuralODE)
[1/9] Time 36.21s Training Accuracy: 0.50963% Test Accuracy: 0.43333%
[2/9] Time 0.50s Training Accuracy: 0.71333% Test Accuracy: 0.66000%
[3/9] Time 0.37s Training Accuracy: 0.78000% Test Accuracy: 0.70000%
[4/9] Time 0.52s Training Accuracy: 0.81037% Test Accuracy: 0.74667%
[5/9] Time 0.41s Training Accuracy: 0.82889% Test Accuracy: 0.78000%
[6/9] Time 0.53s Training Accuracy: 0.84222% Test Accuracy: 0.79333%
[7/9] Time 0.41s Training Accuracy: 0.85926% Test Accuracy: 0.80000%
[8/9] Time 0.56s Training Accuracy: 0.86444% Test Accuracy: 0.82000%
[9/9] Time 0.41s Training Accuracy: 0.88000% Test Accuracy: 0.82667%
We can also change the sensealg and train the model! GaussAdjoint allows you to use an arbitrary parameter structure, not just a flat vector (ComponentArray).
train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
[1/9] Time 39.16s Training Accuracy: 0.50963% Test Accuracy: 0.43333%
[2/9] Time 0.35s Training Accuracy: 0.71111% Test Accuracy: 0.66000%
[3/9] Time 0.35s Training Accuracy: 0.78148% Test Accuracy: 0.70667%
[4/9] Time 0.37s Training Accuracy: 0.80667% Test Accuracy: 0.76000%
[5/9] Time 0.38s Training Accuracy: 0.82519% Test Accuracy: 0.78667%
[6/9] Time 0.57s Training Accuracy: 0.83778% Test Accuracy: 0.79333%
[7/9] Time 0.40s Training Accuracy: 0.85852% Test Accuracy: 0.80000%
[8/9] Time 0.39s Training Accuracy: 0.86593% Test Accuracy: 0.82000%
[9/9] Time 0.57s Training Accuracy: 0.88074% Test Accuracy: 0.82667%
But remember, some AD backends like ReverseDiff are not GPU compatible, so we fall back to the CPU. Even for a model this size, you will notice that training is significantly slower on the CPU than on the GPU.
train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
[1/9] Time 103.66s Training Accuracy: 0.50963% Test Accuracy: 0.43333%
[2/9] Time 18.40s Training Accuracy: 0.69630% Test Accuracy: 0.66000%
[3/9] Time 21.40s Training Accuracy: 0.77926% Test Accuracy: 0.71333%
[4/9] Time 22.59s Training Accuracy: 0.80741% Test Accuracy: 0.76667%
[5/9] Time 23.41s Training Accuracy: 0.82519% Test Accuracy: 0.78000%
[6/9] Time 23.42s Training Accuracy: 0.84074% Test Accuracy: 0.78667%
[7/9] Time 21.43s Training Accuracy: 0.85333% Test Accuracy: 0.80667%
[8/9] Time 22.16s Training Accuracy: 0.86593% Test Accuracy: 0.81333%
[9/9] Time 19.78s Training Accuracy: 0.87704% Test Accuracy: 0.82000%
For completeness, let's also test out discrete sensitivities!
train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
[1/9] Time 38.85s Training Accuracy: 0.50963% Test Accuracy: 0.43333%
[2/9] Time 23.75s Training Accuracy: 0.69630% Test Accuracy: 0.66000%
[3/9] Time 20.51s Training Accuracy: 0.77926% Test Accuracy: 0.71333%
[4/9] Time 22.30s Training Accuracy: 0.80741% Test Accuracy: 0.76667%
[5/9] Time 26.19s Training Accuracy: 0.82519% Test Accuracy: 0.78000%
[6/9] Time 23.35s Training Accuracy: 0.84074% Test Accuracy: 0.78667%
[7/9] Time 26.31s Training Accuracy: 0.85333% Test Accuracy: 0.80667%
[8/9] Time 26.00s Training Accuracy: 0.86593% Test Accuracy: 0.81333%
[9/9] Time 26.28s Training Accuracy: 0.87704% Test Accuracy: 0.82000%
Alternate Implementation using Stateful Layer
Starting v0.5.5, Lux provides a StatefulLuxLayer which can be used to avoid the boxing of st (wrapping it in a Core.Box). Using the @compact API avoids this problem entirely.
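To see why st gets boxed in the NeuralODE implementation above, note that dudt reassigns the captured st on every call. A minimal sketch of the general Julia issue, not tied to Lux:
function boxed(st)
    # reassigning a captured variable forces Julia to wrap it in a
    # Core.Box, which defeats type inference for the closure
    f = u -> (st = st .+ u; st)
    return f
end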
struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <:
Lux.AbstractExplicitContainerLayer{(:model,)}
model::M
solver::So
tspan::T
kwargs::K
end
function StatefulNeuralODE(
model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
return StatefulNeuralODE(model, solver, tspan, kwargs)
end
function (n::StatefulNeuralODE)(x, ps, st)
st_model = StatefulLuxLayer(n.model, ps, st)
dudt(u, p, t) = st_model(u, p)
prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
return solve(prob, n.solver; n.kwargs...), st_model.st
end
Train the new Stateful Neural ODE
train(StatefulNeuralODE)
[1/9] Time 36.53s Training Accuracy: 0.50963% Test Accuracy: 0.43333%
[2/9] Time 0.35s Training Accuracy: 0.71630% Test Accuracy: 0.66667%
[3/9] Time 0.35s Training Accuracy: 0.77926% Test Accuracy: 0.70000%
[4/9] Time 0.38s Training Accuracy: 0.80370% Test Accuracy: 0.74667%
[5/9] Time 0.61s Training Accuracy: 0.82741% Test Accuracy: 0.77333%
[6/9] Time 0.38s Training Accuracy: 0.84148% Test Accuracy: 0.79333%
[7/9] Time 0.39s Training Accuracy: 0.85407% Test Accuracy: 0.80000%
[8/9] Time 0.39s Training Accuracy: 0.86667% Test Accuracy: 0.80667%
[9/9] Time 0.40s Training Accuracy: 0.88000% Test Accuracy: 0.82000%
We might not see a significant difference in training time, but let us investigate the type stability of the layers.
Type Stability
model, ps, st = create_model(NeuralODE)
model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)
x = gpu_device()(ones(Float32, 28, 28, 1, 3));
NeuralODE is not type stable due to the boxing of st
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Main.var"##225".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, ShapedAxis((20, 1))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, ShapedAxis((10, 1))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, ShapedAxis((20, 1))))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-3/julialang/lux-dot-jl/src/layers/containers.jl:518
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Main.var"##225".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, ShapedAxis((20, 1))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, ShapedAxis((10, 1))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, ShapedAxis((20, 1))))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Main.var"##225".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}
│ %2 = Lux.applychain(%1, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└── return %2
We avoid the problem entirely by using StatefulNeuralODE
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Main.var"##225".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, ShapedAxis((20, 1))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, ShapedAxis((10, 1))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, ShapedAxis((20, 1))))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-3/julialang/lux-dot-jl/src/layers/containers.jl:518
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Main.var"##225".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, ShapedAxis((20, 1))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, ShapedAxis((10, 1))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, ShapedAxis((20, 1))))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Main.var"##225".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}
│ %2 = Lux.applychain(%1, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %2
Note that we still recommend using this layer internally, rather than exposing it to users as the default API.
Finally, let us check the compact model.
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)
@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##225".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, ShapedAxis((20, 1))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, ShapedAxis((10, 1))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, ShapedAxis((20, 1))))))),)), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-3/julialang/lux-dot-jl/src/layers/containers.jl:518
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##225".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, ShapedAxis((20, 1))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, ShapedAxis((10, 1))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, ShapedAxis((20, 1))))))),)), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, ShapedAxis((10, 1))))))}}}
st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##225".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_2::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}, layer_3::Lux.Dense{typeof(NNlib.tanh_fast), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{:direct_call, Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32), Static.True}}
│ %2 = Lux.applychain(%1, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %2
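For a programmatic version of these checks, Test.@inferred (a sketch, assuming the Test standard library) errors when the return type of a call cannot be concretely inferred:
using Test
@inferred model_compact(x, ps_compact, st_compact)  # passes: concrete return type
# @inferred model(x, ps, st)                        # would throw: abstract return type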
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxDeviceUtils)
if @isdefined(CUDA) && LuxDeviceUtils.functional(LuxCUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && LuxDeviceUtils.functional(LuxAMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 4 default, 0 interactive, 2 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 4
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0
Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.951 GiB / 4.750 GiB available)
This page was generated using Literate.jl.