Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl, with Enzyme (via Reactant) or Zygote.jl for automatic differentiation.
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and add using Zygote to the imports below.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader yields sequences of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary label indicating whether it is clockwise or anticlockwise.
function create_dataset(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    return x_data, labels
end
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    x_data, labels = create_dataset(; dataset_size, sequence_length)
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
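To make the 2 × seq_len × batch_size layout concrete without MLUtils, here is a small Base Julia sketch. The toy_spiral generator is hypothetical (it is not MLUtils.Datasets.make_spiral and adds no noise); it only shows how per-sequence 2 × seq_len matrices are reshaped and concatenated along dimension 3, mirroring what create_dataset does.

```julia
# Hypothetical generator (not MLUtils.Datasets.make_spiral): an Archimedean spiral
# traced clockwise or anticlockwise, returned as a 2 × seq_len Float32 matrix.
function toy_spiral(seq_len; clockwise=true)
    θ = range(0, 4π; length=seq_len)
    s = clockwise ? -1 : 1              # flipping the angle's sign reverses direction
    r = θ ./ (4π)
    return Float32.(vcat((r .* cos.(s .* θ))', (r .* sin.(s .* θ))'))
end

seq_len, n = 50, 4
# First half clockwise (label 0), second half anticlockwise (label 1), as in create_dataset.
spirals = [reshape(toy_spiral(seq_len; clockwise=(i <= n ÷ 2)), :, seq_len, 1) for i in 1:n]
x_data = cat(spirals...; dims=3)                      # 2 × seq_len × n
labels = vcat(zeros(Float32, n ÷ 2), ones(Float32, n ÷ 2))

size(x_data)              # (2, 50, 4)
size(x_data[:, :, 1:2])   # a "batch" of two sequences: (2, 50, 2)
```

Each slice x_data[:, t, :] is then the set of 2D points at time step t for every sequence in the batch, which is what the recurrent cell consumes one step at a time.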
Creating a Classifier
We will extend the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type so that the parameters and states are populated automatically and we don't have to define Lux.initialparameters and Lux.initialstates ourselves.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
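The effect of those field names can be checked with a tiny container. The TwoDense layer below is a hypothetical illustration (not part of this tutorial's model); it assumes Lux 1.x, where Lux.setup on an AbstractLuxContainerLayer returns parameter and state NamedTuples keyed by the listed field names.

```julia
using Lux, Random

# Hypothetical two-layer container: the tuple (:first, :second) tells Lux which
# fields hold sub-layers, so Lux.setup builds ps/st with exactly these keys.
struct TwoDense{A,B} <: AbstractLuxContainerLayer{(:first, :second)}
    first::A
    second::B
end

# Forward pass: thread each sub-layer's own parameters and state through it.
function (m::TwoDense)(x, ps, st)
    h, st1 = m.first(x, ps.first, st.first)
    y, st2 = m.second(h, ps.second, st.second)
    return y, (first=st1, second=st2)
end

model = TwoDense(Dense(2 => 3), Dense(3 => 1))
ps, st = Lux.setup(Xoshiro(0), model)

keys(ps)                   # (:first, :second)
size(ps.first.weight)      # (3, 2): Dense stores weights as out × in
y, _ = model(rand(Float32, 2, 5), ps, st)
size(y)                    # (1, 5)
```

No initialparameters or initialstates methods were written for TwoDense; the container type supplies them from the field names.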
We won't define the model from scratch but rather use the built-in Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
We could use the default Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of defining the following forward pass, but we will write it out by hand for illustration.
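For comparison, here is a minimal sketch of that built-in route. It assumes the Lux 1.x defaults, under which Recurrence steps the wrapped cell along dimension 2 of a features × seq_len × batch array and returns only the final output. The Chain itself (including the plain vec function at its end, which Chain wraps as a layer) is our illustration rather than the model trained below.

```julia
using Lux, Random

# Illustrative stand-in for SpiralClassifier(2, 8, 1) built only from stock layers.
# Recurrence runs LSTMCell over the time axis (dimension 2) and, by default,
# returns just the output of the last time step.
model = Chain(
    Recurrence(LSTMCell(2 => 8)),
    Dense(8 => 1, sigmoid),
    vec,  # plain functions inside a Chain are wrapped as layers automatically
)

ps, st = Lux.setup(Xoshiro(0), model)

x = randn(Xoshiro(1), Float32, 2, 50, 4)  # 2 features × 50 steps × 4 sequences
y, _ = model(x, ps, st)                    # one probability per sequence
```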
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we run the sequence through the LSTM cell.
    # The first call to the LSTM cell creates the initial hidden state.
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry`, we pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence, we pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally, remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
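The init-then-loop pattern above can be checked in isolation. In this Base Julia sketch, eachslice(x; dims=2) stands in for LuxOps.eachslice(x, Val(2)), and the hypothetical toy_cell (a running sum, not an LSTM) mimics the two call forms used above: the first call takes only the input, later calls take (input, carry).

```julia
# Hypothetical stand-in for the LSTM cell: the "carry" is a running sum.
toy_cell(x::AbstractMatrix) = (x, x)        # first call: build the carry from x alone
function toy_cell(xc::Tuple)                # later calls: combine input with carry
    x, carry = xc
    h = x .+ carry
    return (h, h)
end

function run_sequence(x::AbstractArray{T,3}) where {T}
    # Split the time axis (dimension 2) into the first step and the rest.
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    y, carry = toy_cell(x_init)
    for xt in x_rest
        y, carry = toy_cell((xt, carry))
    end
    return y                                # output after the final time step
end

x = reshape(Float32.(1:24), 2, 3, 4)        # 2 features × 3 steps × 4 sequences
y = run_sequence(x)                         # 2 × 4, equal to the sum over time steps
```

Each slice handed to the cell is features × batch, so the whole batch advances one time step per iteration.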
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Computing the loss directly from logits (logit binary cross-entropy) is typically recommended since it is more numerically stable, but for the sake of simplicity we apply a sigmoid in the classifier head and use the probability-based binary cross-entropy.
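The stability argument can be seen with hand-written versions of both formulations. These helpers are our own Base Julia sketch, not Lux's BinaryCrossEntropyLoss implementation. With ŷ = σ(z), the probability form is -[y log ŷ + (1 - y) log(1 - ŷ)], which simplifies to (1 - y) z - log σ(z) when written in terms of the logit z.

```julia
using Statistics

# Hypothetical helpers written out by hand to compare the two formulations.
sigmoid_(z) = 1 / (1 + exp(-z))

# Probability form: expects ŷ = sigmoid(z) in (0, 1).
bce_prob(ŷ, y) = -mean(y .* log.(ŷ) .+ (1 .- y) .* log.(1 .- ŷ))

# Logit form: (1 - y) z - log σ(z), with log σ(z) = min(z, 0) - log1p(exp(-|z|)).
logsigmoid_(z) = min(z, zero(z)) - log1p(exp(-abs(z)))
bce_logit(z, y) = mean((1 .- y) .* z .- logsigmoid_.(z))

z = Float32[0.5, -1.0, 2.0]
y = Float32[1, 0, 1]
bce_prob(sigmoid_.(z), y) ≈ bce_logit(z, y)   # true for moderate logits

# A confidently wrong prediction: σ(40f0) rounds to 1 in Float32, so log(1 - ŷ) = -Inf.
bce_prob(sigmoid_.(Float32[40]), Float32[0])   # Inf
bce_logit(Float32[40], Float32[0])             # ≈ 40, still finite
```

In this tutorial the sigmoid outputs stay in a comfortable range, so the probability form works, but the logit form avoids the saturation shown in the last two lines.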
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
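As a quick check of the 0.5 threshold, here are the same two definitions applied to a few hypothetical model outputs (the numbers are made up for illustration):

```julia
# Same definitions as above: a prediction counts as class 1 when ŷ > 0.5.
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

y_pred = Float32[0.2, 0.7, 0.9, 0.4]   # hypothetical sigmoid outputs
y_true = Float32[0, 1, 0, 0]
matches(y_pred, y_true)    # 3 (the third sample is misclassified)
accuracy(y_pred, y_true)   # 0.75
```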
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()
    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())
    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))
    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf("Epoch [%3d]: Loss %4.5f\n", epoch, total_loss / total_samples)
        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf(
            "Validation:\tLoss %4.5f\tAccuracy %4.5f\n",
            total_loss / total_samples,
            total_acc / total_samples
        )
    end
    return cpu_device()((train_state.parameters, train_state.states))
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.56034
Validation: Loss 0.49995 Accuracy 1.00000
Epoch [ 2]: Loss 0.47919
Validation: Loss 0.43457 Accuracy 1.00000
Epoch [ 3]: Loss 0.41732
Validation: Loss 0.37449 Accuracy 1.00000
Epoch [ 4]: Loss 0.36183
Validation: Loss 0.31881 Accuracy 1.00000
Epoch [ 5]: Loss 0.30706
Validation: Loss 0.26932 Accuracy 1.00000
Epoch [ 6]: Loss 0.26053
Validation: Loss 0.22607 Accuracy 1.00000
Epoch [ 7]: Loss 0.21532
Validation: Loss 0.18761 Accuracy 1.00000
Epoch [ 8]: Loss 0.18034
Validation: Loss 0.15336 Accuracy 1.00000
Epoch [ 9]: Loss 0.14510
Validation: Loss 0.12397 Accuracy 1.00000
Epoch [ 10]: Loss 0.11857
Validation: Loss 0.09905 Accuracy 1.00000
Epoch [ 11]: Loss 0.09325
Validation: Loss 0.07704 Accuracy 1.00000
Epoch [ 12]: Loss 0.07055
Validation: Loss 0.05727 Accuracy 1.00000
Epoch [ 13]: Loss 0.05224
Validation: Loss 0.04244 Accuracy 1.00000
Epoch [ 14]: Loss 0.03941
Validation: Loss 0.03247 Accuracy 1.00000
Epoch [ 15]: Loss 0.03041
Validation: Loss 0.02610 Accuracy 1.00000
Epoch [ 16]: Loss 0.02476
Validation: Loss 0.02157 Accuracy 1.00000
Epoch [ 17]: Loss 0.02061
Validation: Loss 0.01812 Accuracy 1.00000
Epoch [ 18]: Loss 0.01751
Validation: Loss 0.01541 Accuracy 1.00000
Epoch [ 19]: Loss 0.01479
Validation: Loss 0.01334 Accuracy 1.00000
Epoch [ 20]: Loss 0.01293
Validation: Loss 0.01174 Accuracy 1.00000
Epoch [ 21]: Loss 0.01157
Validation: Loss 0.01048 Accuracy 1.00000
Epoch [ 22]: Loss 0.01030
Validation: Loss 0.00948 Accuracy 1.00000
Epoch [ 23]: Loss 0.00935
Validation: Loss 0.00867 Accuracy 1.00000
Epoch [ 24]: Loss 0.00865
Validation: Loss 0.00799 Accuracy 1.00000
Epoch [ 25]: Loss 0.00796
Validation: Loss 0.00741 Accuracy 1.00000
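The loss printed for each epoch is a sample-weighted mean: main multiplies each batch's mean loss by its batch size and divides the sum by the total sample count. With partial=false every batch has 128 samples, so the weighting only matters if a smaller final batch is kept. A Base Julia sketch with hypothetical per-sample losses shows why the weighting is there:

```julia
using Statistics

# Sample-weighted epoch mean, accumulated the same way as in main.
function epoch_mean(batch_losses)
    total_loss, total_samples = 0.0f0, 0
    for b in batch_losses
        total_loss += mean(b) * length(b)   # batch mean × batch size
        total_samples += length(b)
    end
    return total_loss / total_samples
end

# Hypothetical per-sample losses split into uneven batches.
per_sample = Float32[0.9, 0.7, 0.5, 0.4, 0.2, 0.1, 0.05]
batches = [per_sample[1:3], per_sample[4:6], per_sample[7:7]]

epoch_mean(batches) ≈ mean(per_sample)      # true
mean(mean.(batches)) ≈ mean(per_sample)     # false: unweighted mean of batch means
```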
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.48760
Validation: Loss 0.41910 Accuracy 1.00000
Epoch [ 2]: Loss 0.34760
Validation: Loss 0.29167 Accuracy 1.00000
Epoch [ 3]: Loss 0.25437
Validation: Loss 0.23289 Accuracy 1.00000
Epoch [ 4]: Loss 0.20502
Validation: Loss 0.19176 Accuracy 1.00000
Epoch [ 5]: Loss 0.16703
Validation: Loss 0.15554 Accuracy 1.00000
Epoch [ 6]: Loss 0.13056
Validation: Loss 0.11228 Accuracy 1.00000
Epoch [ 7]: Loss 0.09093
Validation: Loss 0.07573 Accuracy 1.00000
Epoch [ 8]: Loss 0.06429
Validation: Loss 0.05762 Accuracy 1.00000
Epoch [ 9]: Loss 0.04965
Validation: Loss 0.04575 Accuracy 1.00000
Epoch [ 10]: Loss 0.03996
Validation: Loss 0.03729 Accuracy 1.00000
Epoch [ 11]: Loss 0.03271
Validation: Loss 0.03073 Accuracy 1.00000
Epoch [ 12]: Loss 0.02712
Validation: Loss 0.02544 Accuracy 1.00000
Epoch [ 13]: Loss 0.02261
Validation: Loss 0.02104 Accuracy 1.00000
Epoch [ 14]: Loss 0.01875
Validation: Loss 0.01744 Accuracy 1.00000
Epoch [ 15]: Loss 0.01578
Validation: Loss 0.01472 Accuracy 1.00000
Epoch [ 16]: Loss 0.01352
Validation: Loss 0.01279 Accuracy 1.00000
Epoch [ 17]: Loss 0.01193
Validation: Loss 0.01139 Accuracy 1.00000
Epoch [ 18]: Loss 0.01072
Validation: Loss 0.01030 Accuracy 1.00000
Epoch [ 19]: Loss 0.00978
Validation: Loss 0.00939 Accuracy 1.00000
Epoch [ 20]: Loss 0.00893
Validation: Loss 0.00860 Accuracy 1.00000
Epoch [ 21]: Loss 0.00820
Validation: Loss 0.00789 Accuracy 1.00000
Epoch [ 22]: Loss 0.00753
Validation: Loss 0.00725 Accuracy 1.00000
Epoch [ 23]: Loss 0.00693
Validation: Loss 0.00666 Accuracy 1.00000
Epoch [ 24]: Loss 0.00637
Validation: Loss 0.00612 Accuracy 1.00000
Epoch [ 25]: Loss 0.00587
Validation: Loss 0.00563 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that main already transfers the parameters and states to the CPU before returning them, so they are ready to be saved. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
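Since any serialization library works, here is a sketch using Julia's built-in Serialization standard library as a swapped-in alternative to JLD2. The ps_demo and st_demo NamedTuples are hypothetical stand-ins for ps_trained and st_trained. Keep in mind that Serialization's format is not guaranteed to be readable across different Julia versions, which is one reason to prefer JLD2 for long-term storage.

```julia
using Serialization

# Hypothetical stand-ins for trained parameters/states: plain NamedTuples of CPU arrays.
ps_demo = (dense=(weight=rand(Float32, 1, 8), bias=zeros(Float32, 1)),)
st_demo = (dense=NamedTuple(),)

path = joinpath(mktempdir(), "trained_model.jls")
serialize(path, (ps=ps_demo, st=st_demo))   # save only ps and st, not the model struct
loaded = deserialize(path)

loaded.ps.dense.weight == ps_demo.dense.weight   # true
```

To use the restored values, rebuild the model with the same constructor (for example SpiralClassifier(2, 8, 1)) and call it with the loaded parameters and states.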
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.