Neural Networks Inside GPU Kernels

On this page, we describe how to embed neural networks inside GPU kernels. We use KernelAbstractions.jl to do this, which makes the approach compatible with multiple GPU backends.

Experimental Feature

This is a relatively new and experimental feature. Expect edge cases, and open issues on GitHub if you run into any.

Inference Only

Currently, this works only for inference. We will eventually test automatic differentiation using Enzyme.jl.

Batching

In most use cases, this form of batching, i.e. embedding the neural network inside a GPU kernel, is not recommended and will lead to suboptimal performance. Instead, batch the input data and let Lux handle the batching internally, as in the sketch below.
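
For reference, a minimal sketch of the recommended batched path looks like this (the names xs_batched and ys are illustrative):

julia
using Lux, Random

nn = Chain(Dense(4, 4, relu), Dense(4, 4))
ps, st = Lux.setup(Xoshiro(123), nn)

xs_batched = rand(Float64, 4, 1024)        # features × batch
ys, _ = Lux.apply(nn, xs_batched, ps, st)  # Lux batches over the second dimension
size(ys)                                   # (4, 1024)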

julia
using Lux, LuxCUDA, Random, Functors
using KernelAbstractions, StaticArrays

The first thing to remember is that we can't use regular high-level operations inside GPU kernels; instead, we will use StaticArrays. Leveraging Julia's multiple dispatch, Lux will use specialized operations that are compatible with GPU kernels.

julia
@kernel function nn_eval_single_batch!(output, model, input, ps, st)
    i = @index(Global, Linear)                       # one work-item per input sample
    y, st_ = Lux.apply(model, input[i], ps, st)      # forward pass on a single static sample
    output[i] = y
end
nn_eval_single_batch! (generic function with 4 methods)

We define and initialize the neural network as usual, but we additionally need to convert the Arrays into SArrays.

julia
nn = Chain(Dense(4, 4, relu), Dense(4, 4))
ps, st = Lux.setup(Xoshiro(123), nn)

to_sarray(x) = SArray{Tuple{size(x)...}}(x)
ps_static = fmap(to_sarray, ps)
st_static = fmap(to_sarray, st)
(layer_1 = NamedTuple(), layer_2 = NamedTuple())
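
As a quick check that the static code path is picked up, we can apply the model to a single static input outside any kernel; the result should itself be a static 4×1 matrix:

julia
x_static = @SArray(rand(Float64, 4, 1))                      # a single sample as a 4×1 static matrix
y_static, _ = Lux.apply(nn, x_static, ps_static, st_static)  # should return a static 4×1 matrix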

First, we will run it on the CPU.

Warning

Currently, due to a minor bug, we cannot call Lux models with vector inputs. As a workaround, we make the inputs matrices with a batch size of 1.

julia
input = [@SArray(rand(Float64, 4, 1)) for i in 1:1024]
output = [@SArray(zeros(Float64, 4, 1)) for i in 1:1024] # Allocate the output
1024-element Vector{StaticArraysCore.SMatrix{4, 1, Float64, 4}}:
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 ⋮
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]

Now, we run the model using KernelAbstractions.jl.

julia
backend = KernelAbstractions.get_backend(output)
cpu_kernel! = nn_eval_single_batch!(backend)
cpu_kernel!(output, nn, input, ps_static, st_static; ndrange=length(output))
KernelAbstractions.synchronize(backend)
output
1024-element Vector{StaticArraysCore.SMatrix{4, 1, Float64, 4}}:
 [2.0564903986057956; 1.1188200246206075; -1.2227837233928576; -0.8173783982243132;;]
 [1.9721554734769875; 1.3940224213371761; -1.297959481822617; -0.7195462169922175;;]
 [2.5680085614623662; 1.713567516238075; -1.7165512278088038; -1.009963844931984;;]
 [1.800792614736468; 0.36222499022985155; -1.1204217935313214; -1.1836515766351254;;]
 [1.486550215883336; 0.32839986131789933; -0.9019142280758281; -0.9452923791531558;;]
 [2.716134755899883; 1.1617228180412864; -1.902982902377702; -1.5865265807660498;;]
 [1.0228109822209213; 0.2525357728685884; -0.4376572711003852; -0.4500963619011972;;]
 [2.2771862617010155; 0.5381101016248151; -1.4730743722547668; -1.488028235902512;;]
 [3.2791573282471997; 1.3436353225087703; -2.4619778701480337; -2.1239749674027375;;]
 [1.2290224145974982; 0.4158693023143286; -0.6370531107315014; -0.5779067839062536;;]
 ⋮
 [1.8674763752817416; 1.6423511984038721; -1.1477053709248992; -0.3834447782571344;;]
 [2.091359335844565; 1.0621559246995447; -1.4763277207638008; -1.142470881033475;;]
 [2.712979078066394; 0.42005835019799886; -1.717863343114228; -1.8601870861800127;;]
 [0.7701346738750905; 0.2869913410456831; -0.1586047939092094; -0.10140238162746013;;]
 [1.611584190904272; 1.2797048270773437; -0.923950547913545; -0.3558193508137715;;]
 [2.0884834705765853; 0.862480861009647; -1.3942307655311696; -1.179584495291061;;]
 [2.390200114697191; 0.5267549745189349; -1.657670184695808; -1.7089496198123055;;]
 [2.1846486482317626; -0.031414255389526885; -1.3279041356366077; -1.6909446526419574;;]
 [1.3650193059617517; 0.5210742834996898; -0.7689272356710357; -0.6642563709240284;;]
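
As a quick sanity check, the kernel results can be compared against applying the model directly to each sample:

julia
output_direct = [first(Lux.apply(nn, x, ps_static, st_static)) for x in input]
all(output .≈ output_direct)   # expected to be true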

Now, we will run the same model on the GPU.

julia
gdev = gpu_device()

input_gpu = input |> gdev
output_gpu = [@SArray(zeros(Float64, 4, 1)) for i in 1:1024] |> gdev
1024-element CuArray{StaticArraysCore.SMatrix{4, 1, Float64, 4}, 1, CUDA.DeviceMemory}:
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 ⋮
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]
 [0.0; 0.0; 0.0; 0.0;;]

julia
backend = KernelAbstractions.get_backend(output_gpu)
gpu_kernel! = nn_eval_single_batch!(backend)
gpu_kernel!(output_gpu, nn, input_gpu, ps_static, st_static; ndrange=length(output_gpu))
KernelAbstractions.synchronize(backend)
output_gpu
1024-element CuArray{StaticArraysCore.SMatrix{4, 1, Float64, 4}, 1, CUDA.DeviceMemory}:
 [2.0564903986057956; 1.1188200246206075; -1.2227837233928576; -0.8173783982243132;;]
 [1.9721554734769875; 1.3940224213371761; -1.297959481822617; -0.7195462169922173;;]
 [2.5680085614623662; 1.713567516238075; -1.7165512278088038; -1.009963844931984;;]
 [1.800792614736468; 0.36222499022985155; -1.1204217935313214; -1.1836515766351254;;]
 [1.486550215883336; 0.32839986131789933; -0.9019142280758281; -0.9452923791531558;;]
 [2.716134755899883; 1.1617228180412864; -1.902982902377702; -1.5865265807660498;;]
 [1.0228109822209213; 0.2525357728685884; -0.4376572711003852; -0.4500963619011972;;]
 [2.2771862617010155; 0.5381101016248151; -1.4730743722547668; -1.488028235902512;;]
 [3.2791573282471997; 1.3436353225087703; -2.4619778701480337; -2.1239749674027375;;]
 [1.2290224145974982; 0.4158693023143286; -0.6370531107315014; -0.5779067839062536;;]
 ⋮
 [1.8674763752817414; 1.6423511984038721; -1.147705370924899; -0.3834447782571341;;]
 [2.0913593358445652; 1.062155924699545; -1.4763277207638013; -1.142470881033475;;]
 [2.712979078066394; 0.420058350197999; -1.717863343114228; -1.8601870861800127;;]
 [0.7701346738750905; 0.2869913410456831; -0.1586047939092094; -0.10140238162746013;;]
 [1.611584190904272; 1.2797048270773437; -0.923950547913545; -0.3558193508137715;;]
 [2.0884834705765853; 0.862480861009647; -1.3942307655311696; -1.179584495291061;;]
 [2.390200114697191; 0.5267549745189349; -1.657670184695808; -1.7089496198123055;;]
 [2.1846486482317626; -0.031414255389526885; -1.3279041356366077; -1.6909446526419574;;]
 [1.3650193059617517; 0.5210742834996898; -0.7689272356710357; -0.6642563709240284;;]
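
Finally, the GPU results can be moved back to the host with cpu_device() and compared against the CPU kernel output; tiny floating-point differences in the last digits are expected:

julia
cdev = cpu_device()
output_from_gpu = output_gpu |> cdev   # CuArray of static matrices -> Vector of static matrices
all(output .≈ output_from_gpu)         # expected to be true up to floating-point tolerance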