Neural Networks Inside GPU Kernels
On this page, we describe how to embed neural networks inside GPU kernels. We use KernelAbstractions.jl to do this, which makes the approach compatible with multiple GPU backends.
Experimental Feature
This is a relatively new and experimental feature. Expect edge cases, and please open issues on GitHub if you find any.
Inference Only
Currently, this works only for inference. We will eventually test automatic differentiation using Enzyme.jl.
Batching
In most use cases, this form of batching (embedding the neural network inside a GPU kernel) is not recommended and will lead to suboptimal performance. Instead, batch the input data and let Lux handle the batching internally.
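For comparison, here is a minimal sketch of the recommended approach. Stack all samples as the columns of one matrix (features × batch) and make a single ordinary Lux call, so each Dense layer runs one matrix multiply over the whole batch. It reuses the model architecture from later on this page. The input `x`, its seed and its size are illustrative choices, not part of the original tutorial.

```julia
using Lux, Random

# Same architecture as the kernel example on this page
model = Chain(Dense(4, 4, relu), Dense(4, 4))
ps, st = Lux.setup(Xoshiro(123), model)

# Each column is one sample: 4 features × 1024 samples
x = rand(Xoshiro(0), Float32, 4, 1024)

# One call evaluates the whole batch; each Dense layer does a single matrix multiply
y, _ = model(x, ps, st)
```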
using Lux, LuxCUDA, Random, Functors
using KernelAbstractions, StaticArrays
The first thing to remember is that we can't use regular high-level operations inside kernels; instead, we use StaticArrays. Leveraging Julia's multiple dispatch, Lux then uses specialized operations that are compatible with GPU kernels.
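@kernel function nn_eval_single_batch!(output, model, input, ps, st)
    # One work-item per sample: i runs over 1:length(output)
    i = @index(Global, Linear)
    # Lux.apply returns (y, new_state); the updated state is discarded for inference
    y, st_ = Lux.apply(model, input[i], ps, st)
    output[i] = y
end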
nn_eval_single_batch! (generic function with 4 methods)
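To show why StaticArrays fit this setting, the sketch below evaluates one dense layer with `relu` by hand on static values. It only illustrates the kind of per-sample arithmetic involved and is not Lux's internal code. `W`, `b` and `x` are hypothetical values chosen so the result is easy to check by hand. Every intermediate is a fixed-size isbits value, so no heap allocation is needed. That is the property that lets such values live inside a kernel.

```julia
using StaticArrays

# Hypothetical 4×4 weight, length-4 bias and a single sample stored as a 4×1 matrix
W = @SMatrix [1.0 -1.0 0.0 0.0; 0.0 2.0 0.0 0.0; 0.0 0.0 -3.0 0.0; 0.5 0.5 0.5 0.5]
b = SVector(0.1, -0.2, 0.3, -0.4)
x = SMatrix{4,1}(1.0, 2.0, 3.0, 4.0)

# relu(W * x .+ b): W * x is a 4×1 SMatrix, and broadcasting with the SVector keeps it 4×1
y = max.(W * x .+ b, 0.0)
```

Here `W * x` is `(-1, 4, -9, 5)`, adding `b` gives roughly `(-0.9, 3.8, -8.7, 4.6)`, and `relu` zeroes the negative entries.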
We define and initialize the neural network as usual, but we additionally need to convert the Arrays in the parameters and states into SArrays.
nn = Chain(Dense(4, 4, relu), Dense(4, 4))
ps, st = Lux.setup(Xoshiro(123), nn)
to_sarray(x) = SArray{Tuple{size(x)...}}(x)  # encode each array's size in its type
ps_static = fmap(to_sarray, ps)
st_static = fmap(to_sarray, st)
(layer_1 = NamedTuple(), layer_2 = NamedTuple())
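The conversion can be checked in isolation. The sketch below applies the same `to_sarray` helper to a hand-built, hypothetical parameter NamedTuple whose `weight`/`bias` field names only mimic a Lux layout. `fmap` rebuilds the nested structure with SArray leaves. The result is an isbits value. That is presumably also why `ps_static` and `st_static` can later be handed to the GPU kernel as-is, without a `gdev` transfer, but treat that explanation as an assumption.

```julia
using Functors, StaticArrays, Random

# Same helper as on this page: bake each array's size into its type
to_sarray(x) = SArray{Tuple{size(x)...}}(x)

# Hypothetical hand-built parameters with Lux-like field names
rng = Xoshiro(1)
ps = (layer_1 = (weight = rand(rng, Float32, 4, 4), bias = rand(rng, Float32, 4)),
      layer_2 = (weight = rand(rng, Float32, 4, 4), bias = rand(rng, Float32, 4)))

# fmap walks the NamedTuples and converts every array leaf
ps_static = fmap(to_sarray, ps)
```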
First, we run it on the CPU.
Warning
Currently, due to a minor bug, we cannot call Lux models with vector input. As a workaround, we make the inputs matrices with a batch size of 1.
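If the samples already exist as vectors, here is a minimal sketch of that workaround. It reinterprets each one as a 4×1 matrix. The sample `v` is hypothetical, and the static and plain-array variants are shown side by side.

```julia
using StaticArrays

# A hypothetical single sample stored as a static vector
v = SVector(0.1, 0.2, 0.3, 0.4)

# Workaround: the same 4 values as a 4×1 static matrix (batch size 1)
x = SMatrix{4,1}(Tuple(v))

# For plain Arrays, the equivalent is reshape(v, :, 1)
xa = reshape([0.1, 0.2, 0.3, 0.4], :, 1)
```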
input = [@SArray(rand(Float64, 4, 1)) for i in 1:1024]
output = [@SArray(zeros(Float64, 4, 1)) for i in 1:1024] # Allocate the output
1024-element Vector{StaticArraysCore.SMatrix{4, 1, Float64, 4}}:
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
⋮
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
Now, run the model using KernelAbstractions.jl.
backend = KernelAbstractions.get_backend(output)
cpu_kernel! = nn_eval_single_batch!(backend)
cpu_kernel!(output, nn, input, ps_static, st_static; ndrange=length(output))
KernelAbstractions.synchronize(backend)
output
1024-element Vector{StaticArraysCore.SMatrix{4, 1, Float64, 4}}:
[2.0564903986057956; 1.1188200246206075; -1.2227837233928576; -0.8173783982243132;;]
[1.9721554734769875; 1.3940224213371761; -1.297959481822617; -0.7195462169922175;;]
[2.5680085614623662; 1.713567516238075; -1.7165512278088038; -1.009963844931984;;]
[1.800792614736468; 0.36222499022985155; -1.1204217935313214; -1.1836515766351254;;]
[1.486550215883336; 0.32839986131789933; -0.9019142280758281; -0.9452923791531558;;]
[2.716134755899883; 1.1617228180412864; -1.902982902377702; -1.5865265807660498;;]
[1.0228109822209213; 0.2525357728685884; -0.4376572711003852; -0.4500963619011972;;]
[2.2771862617010155; 0.5381101016248151; -1.4730743722547668; -1.488028235902512;;]
[3.2791573282471997; 1.3436353225087703; -2.4619778701480337; -2.1239749674027375;;]
[1.2290224145974982; 0.4158693023143286; -0.6370531107315014; -0.5779067839062536;;]
⋮
[1.8674763752817416; 1.6423511984038721; -1.1477053709248992; -0.3834447782571344;;]
[2.091359335844565; 1.0621559246995447; -1.4763277207638008; -1.142470881033475;;]
[2.712979078066394; 0.42005835019799886; -1.717863343114228; -1.8601870861800127;;]
[0.7701346738750905; 0.2869913410456831; -0.1586047939092094; -0.10140238162746013;;]
[1.611584190904272; 1.2797048270773437; -0.923950547913545; -0.3558193508137715;;]
[2.0884834705765853; 0.862480861009647; -1.3942307655311696; -1.179584495291061;;]
[2.390200114697191; 0.5267549745189349; -1.657670184695808; -1.7089496198123055;;]
[2.1846486482317626; -0.031414255389526885; -1.3279041356366077; -1.6909446526419574;;]
[1.3650193059617517; 0.5210742834996898; -0.7689272356710357; -0.6642563709240284;;]
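A cheap sanity check is to compare the per-sample kernel results with one batched call of the same model on the same parameters. The sketch below re-runs the setup from this page so that it stands alone. Packing the samples into a 4×n matrix with a comprehension is just one way to build the reference input. Because the Float32 parameters meet Float64 inputs, Lux may print a mixed-precision warning for the batched call.

```julia
using Lux, Random, Functors, StaticArrays, KernelAbstractions

@kernel function nn_eval_single_batch!(output, model, input, ps, st)
    i = @index(Global, Linear)
    y, _ = Lux.apply(model, input[i], ps, st)
    output[i] = y
end

nn = Chain(Dense(4, 4, relu), Dense(4, 4))
ps, st = Lux.setup(Xoshiro(123), nn)
to_sarray(x) = SArray{Tuple{size(x)...}}(x)
ps_static, st_static = fmap(to_sarray, ps), fmap(to_sarray, st)

n = 1024
input = [@SArray(rand(Float64, 4, 1)) for _ in 1:n]
output = [@SArray(zeros(Float64, 4, 1)) for _ in 1:n]

# Per-sample evaluation inside a kernel, as on this page
backend = KernelAbstractions.get_backend(output)
nn_eval_single_batch!(backend)(output, nn, input, ps_static, st_static; ndrange=n)
KernelAbstractions.synchronize(backend)

# Reference: samples as the columns of one 4×n matrix, evaluated in a single call
X = [input[j][k] for k in 1:4, j in 1:n]
Y, _ = nn(X, ps, st)
```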
Now we run the same model on the GPU.
gdev = gpu_device()
input_gpu = input |> gdev
output_gpu = [@SArray(zeros(Float64, 4, 1)) for i in 1:1024] |> gdev
1024-element CuArray{StaticArraysCore.SMatrix{4, 1, Float64, 4}, 1, CUDA.DeviceMemory}:
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
⋮
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
[0.0; 0.0; 0.0; 0.0;;]
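Instead of building the zero-filled output on the host and transferring it, one could allocate it directly on the target backend with KernelAbstractions. The sketch below runs on the `CPU()` backend so it works anywhere. On an NVIDIA GPU, the assumption is that you would pass CUDA.jl's `CUDABackend()` in its place.

```julia
using KernelAbstractions, StaticArrays

T = SMatrix{4, 1, Float64, 4}   # element type used on this page

# Zero-filled output allocated directly on the chosen backend
backend = CPU()
output = KernelAbstractions.zeros(backend, T, 1024)
```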
backend = KernelAbstractions.get_backend(output_gpu)
gpu_kernel! = nn_eval_single_batch!(backend)
gpu_kernel!(output_gpu, nn, input_gpu, ps_static, st_static; ndrange=length(output_gpu))
KernelAbstractions.synchronize(backend)
output_gpu
1024-element CuArray{StaticArraysCore.SMatrix{4, 1, Float64, 4}, 1, CUDA.DeviceMemory}:
[2.0564903986057956; 1.1188200246206075; -1.2227837233928576; -0.8173783982243132;;]
[1.9721554734769875; 1.3940224213371761; -1.297959481822617; -0.7195462169922173;;]
[2.5680085614623662; 1.713567516238075; -1.7165512278088038; -1.009963844931984;;]
[1.800792614736468; 0.36222499022985155; -1.1204217935313214; -1.1836515766351254;;]
[1.486550215883336; 0.32839986131789933; -0.9019142280758281; -0.9452923791531558;;]
[2.716134755899883; 1.1617228180412864; -1.902982902377702; -1.5865265807660498;;]
[1.0228109822209213; 0.2525357728685884; -0.4376572711003852; -0.4500963619011972;;]
[2.2771862617010155; 0.5381101016248151; -1.4730743722547668; -1.488028235902512;;]
[3.2791573282471997; 1.3436353225087703; -2.4619778701480337; -2.1239749674027375;;]
[1.2290224145974982; 0.4158693023143286; -0.6370531107315014; -0.5779067839062536;;]
⋮
[1.8674763752817414; 1.6423511984038721; -1.147705370924899; -0.3834447782571341;;]
[2.0913593358445652; 1.062155924699545; -1.4763277207638013; -1.142470881033475;;]
[2.712979078066394; 0.420058350197999; -1.717863343114228; -1.8601870861800127;;]
[0.7701346738750905; 0.2869913410456831; -0.1586047939092094; -0.10140238162746013;;]
[1.611584190904272; 1.2797048270773437; -0.923950547913545; -0.3558193508137715;;]
[2.0884834705765853; 0.862480861009647; -1.3942307655311696; -1.179584495291061;;]
[2.390200114697191; 0.5267549745189349; -1.657670184695808; -1.7089496198123055;;]
[2.1846486482317626; -0.031414255389526885; -1.3279041356366077; -1.6909446526419574;;]
[1.3650193059617517; 0.5210742834996898; -0.7689272356710357; -0.6642563709240284;;]
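The CPU and GPU listings agree except in the last digit or two of a few entries (for example, the last component of the second row). This is most likely ordinary floating-point rounding from a different order of operations, so compare the two with a tolerance rather than exact equality. `max_abs_diff` below is a hypothetical helper, not part of Lux or StaticArrays.

```julia
using StaticArrays

# Hypothetical helper: largest absolute elementwise difference between two
# equally long vectors of static arrays (both must already be in host memory)
function max_abs_diff(a::AbstractVector{<:StaticArray}, b::AbstractVector{<:StaticArray})
    length(a) == length(b) || throw(DimensionMismatch("lengths $(length(a)) and $(length(b)) differ"))
    return maximum(maximum(abs.(x .- y)) for (x, y) in zip(a, b))
end

a = [SMatrix{2,1}(1.0, 2.0), SMatrix{2,1}(3.0, 4.0)]
b = [SMatrix{2,1}(1.0, 2.0), SMatrix{2,1}(3.0, 4.5)]
max_abs_diff(a, b)   # 0.5
```

With the arrays from this page, one would call `max_abs_diff(output, Array(output_gpu))`. A result on the order of 1e-16 rather than exactly zero is consistent with the differences visible above.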