Not Run on CI
This tutorial is not run on CI to reduce the computational burden. If you encounter any issues, please open an issue on the Lux.jl repository.
Qwen3 Implementation from Scratch
This is a from-scratch implementation of Qwen 3 (see the blog post and technical report), based on a reference PyTorch implementation released under the Apache License 2.0.
Package Imports
julia
using BFloat16s, ConcreteStructs, LinearAlgebra, Lux, Random, Reactant
using HuggingFaceTokenizers, PythonCall, SafeTensors, Scratch, JSON3
const huggingface_hub = pyimport("huggingface_hub")
Qwen3 Configuration
julia
@kwdef struct Qwen3Config{F}
version::String
vocab_size::Int
context_length::Int
emb_dim::Int
n_heads::Int
n_layers::Int
hidden_dim::Int
head_dim::Int
n_kv_groups::Int
rope_base::Float32
dtype::F
reasoning_model::Bool = true
end
const QWEN3_VERSION_OVERRIDES = Dict(
"0.6B" => (; emb_dim=1024, n_heads=16, n_layers=28, hidden_dim=3072),
"1.7B" => (; emb_dim=2048, n_heads=16, n_layers=28, hidden_dim=6144),
"4B" => (; emb_dim=2560, n_heads=32, n_layers=36, hidden_dim=9728),
"8B" => (; emb_dim=4096, n_heads=32, n_layers=36, hidden_dim=12288),
"14B" => (; emb_dim=5120, n_heads=40, n_layers=40, hidden_dim=17408),
"32B" => (; emb_dim=5120, n_heads=64, n_layers=64, hidden_dim=25600),
)
function Qwen3Config(version::String; kwargs...)
haskey(QWEN3_VERSION_OVERRIDES, version) ||
throw(ArgumentError("Unknown Qwen3 version $version"))
# all versions share the vocabulary, context length, head size, KV grouping,
# RoPE base, and dtype; only the width and depth settings differ
return Qwen3Config(;
version,
vocab_size=151_936,
context_length=40_960,
head_dim=128,
n_kv_groups=8,
rope_base=1.0f6,
dtype=bf16,
QWEN3_VERSION_OVERRIDES[version]...,
kwargs...,
)
end
# map Lux's eltype-conversion functions (`f16`, `f32`, `f64`, `bf16`) to the
# corresponding Julia element types
fn_to_dtype(::Type{T}) where {T} = T
fn_to_dtype(::typeof(f16)) = Float16
fn_to_dtype(::typeof(f32)) = Float32
fn_to_dtype(::typeof(f64)) = Float64
fn_to_dtype(::typeof(bf16)) = BFloat16
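As a quick sanity check, we can instantiate one of the presets and confirm the dtype mapping (a small usage sketch, not part of the generation pipeline):
julia
cfg = Qwen3Config("0.6B")
@assert cfg.emb_dim == 1024 && cfg.n_layers == 28
@assert fn_to_dtype(cfg.dtype) === BFloat16  # `bf16` stores parameters as `BFloat16`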
Model Definition
julia
# SwiGLU-style gated MLP: `Parallel(.*; ...)` computes
# `swish(gate_proj(x)) .* up_proj(x)`, and `down_proj` maps the result back to
# `emb_dim`
function Qwen3MLP(cfg::Qwen3Config)
return Chain(;
proj=Parallel(
.*;
gate_proj=Dense(cfg.emb_dim => cfg.hidden_dim, swish; use_bias=false),
up_proj=Dense(cfg.emb_dim => cfg.hidden_dim; use_bias=false),
),
down_proj=Dense(cfg.hidden_dim => cfg.emb_dim; use_bias=false),
name="Qwen3MLP",
)
end
Qwen3RMSNorm(emb_dim::Int, eps) = AlternatePrecision{Float32}(RMSNorm(emb_dim; epsilon=eps))
@concrete struct GroupedQueryAttention <: AbstractLuxContainerLayer{(
:q_proj, :k_proj, :v_proj, :o_proj, :q_norm, :k_norm
)}
q_proj
k_proj
v_proj
o_proj
q_norm
k_norm
d_in::Int
num_heads::Int
num_kv_groups::Int
head_dim::Int
end
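# In grouped-query attention the `num_heads` query heads share only
# `num_kv_groups` key/value heads (16 queries over 8 KV groups in the 0.6B
# model, i.e. 2 query heads per KV head), which shrinks the K/V projections
# without changing the size of the attention output.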
function GroupedQueryAttention(d_in, num_heads, num_kv_groups; head_dim=nothing)
@assert num_heads % num_kv_groups == 0 "num_heads must be divisible by num_kv_groups"
if head_dim === nothing
@assert d_in % num_heads == 0 "`d_in` must be divisible by `num_heads` if \
`head_dim` is not set"
head_dim = d_in ÷ num_heads
end
d_out = num_heads * head_dim
return GroupedQueryAttention(
Dense(d_in, d_out; use_bias=false),
Dense(d_in, num_kv_groups * head_dim; use_bias=false),
Dense(d_in, num_kv_groups * head_dim; use_bias=false),
Dense(d_out, d_in; use_bias=false),
Qwen3RMSNorm(head_dim, 1.0f-6),
Qwen3RMSNorm(head_dim, 1.0f-6),
d_in,
num_heads,
num_kv_groups,
head_dim,
)
end
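# The Float32 cos/sin caches are precomputed once for the full context length
# (see `initialstates` below). `apply_rotary_embedding` rotates consecutive
# feature pairs by their position-dependent angles; we convert the result back
# to the activation eltype `T` (e.g. `BFloat16`) since the caches are kept in
# `Float32` for accuracy.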
function apply_rope(x::AbstractArray{T}, cos_cache, sin_cache) where {T}
return T.(apply_rotary_embedding(x, cos_cache, sin_cache; seq_dim=3))
end
function (attn::GroupedQueryAttention)((x, cos_cache, sin_cache), ps, st::NamedTuple)
_, num_tokens, B = size(x)
# apply projections
queries, st_q_proj = attn.q_proj(x, ps.q_proj, st.q_proj)
keys, st_k_proj = attn.k_proj(x, ps.k_proj, st.k_proj)
values, st_v_proj = attn.v_proj(x, ps.v_proj, st.v_proj)
# reshape to split heads: (head_dim, num_heads or num_kv_groups, num_tokens, batch)
queries = reshape(queries, attn.head_dim, attn.num_heads, num_tokens, B)
keys = reshape(keys, attn.head_dim, attn.num_kv_groups, num_tokens, B)
values = reshape(values, attn.head_dim, attn.num_kv_groups, num_tokens, B)
# apply normalization
queries, st_q_norm = attn.q_norm(queries, ps.q_norm, st.q_norm)
keys, st_k_norm = attn.k_norm(keys, ps.k_norm, st.k_norm)
# apply RoPE
queries = apply_rope(queries, cos_cache, sin_cache)
keys = apply_rope(keys, cos_cache, sin_cache)
# attention
context = reshape(
scaled_dot_product_attention(
queries, keys, values; head_dim=1, token_dim=3, is_causal=true
)[1],
attn.head_dim * attn.num_heads,
num_tokens,
B,
)
# output projection
proj, st_o_proj = attn.o_proj(context, ps.o_proj, st.o_proj)
return (
proj,
(;
q_proj=st_q_proj,
k_proj=st_k_proj,
v_proj=st_v_proj,
o_proj=st_o_proj,
q_norm=st_q_norm,
k_norm=st_k_norm,
),
)
end
@concrete struct Qwen3Attention <: AbstractLuxContainerLayer{(
:self_attn, :mlp, :input_layernorm, :post_attention_layernorm
)}
self_attn <: GroupedQueryAttention
mlp
input_layernorm
post_attention_layernorm
end
function Qwen3Attention(cfg::Qwen3Config)
return Qwen3Attention(
GroupedQueryAttention(cfg.emb_dim, cfg.n_heads, cfg.n_kv_groups; cfg.head_dim),
Qwen3MLP(cfg),
Qwen3RMSNorm(cfg.emb_dim, 1.0f-6),
Qwen3RMSNorm(cfg.emb_dim, 1.0f-6),
)
end
function (block::Qwen3Attention)((x, cos_cache, sin_cache), ps, st::NamedTuple)
# shortcut connection for attention block
shortcut = x
x, st_norm1 = block.input_layernorm(x, ps.input_layernorm, st.input_layernorm)
x, st_attn = block.self_attn((x, cos_cache, sin_cache), ps.self_attn, st.self_attn)
x = x .+ shortcut
# shortcut connection for feed-forward block
shortcut = x
x, st_norm2 = block.post_attention_layernorm(
x, ps.post_attention_layernorm, st.post_attention_layernorm
)
x, st_ff = block.mlp(x, ps.mlp, st.mlp)
x = x .+ shortcut
return (
x,
(;
self_attn=st_attn,
mlp=st_ff,
input_layernorm=st_norm1,
post_attention_layernorm=st_norm2,
),
)
end
@concrete struct Qwen3 <:
AbstractLuxContainerLayer{(:embed_tokens, :blocks, :norm, :lm_head)}
embed_tokens
blocks
norm
lm_head
cfg::Qwen3Config
end
function Qwen3(cfg::Qwen3Config)
return Qwen3(
Embedding(cfg.vocab_size => cfg.emb_dim),
Tuple([Qwen3Attention(cfg) for _ in 1:(cfg.n_layers)]),
Qwen3RMSNorm(cfg.emb_dim, 1.0f-6),
Dense(cfg.emb_dim, cfg.vocab_size; use_bias=false),
cfg,
)
end
function LuxCore.initialstates(rng::AbstractRNG, m::Qwen3)
head_dim = m.cfg.head_dim === nothing ? m.cfg.emb_dim ÷ m.cfg.n_heads : m.cfg.head_dim
(; cos_cache, sin_cache) = compute_rotary_embedding_params(
head_dim, m.cfg.context_length; base=m.cfg.rope_base, dtype=Float32
)
return (;
cos_cache,
sin_cache,
embed_tokens=LuxCore.initialstates(rng, m.embed_tokens),
blocks=LuxCore.initialstates(rng, m.blocks),
norm=LuxCore.initialstates(rng, m.norm),
lm_head=LuxCore.initialstates(rng, m.lm_head),
)
end
function (qwen3::Qwen3)(in_idx, ps, st::NamedTuple)
x, st_embed_tokens = qwen3.embed_tokens(in_idx, ps.embed_tokens, st.embed_tokens)
st_blocks = ()
for (i, block) in enumerate(qwen3.blocks)
x, st_block_new = block((x, st.cos_cache, st.sin_cache), ps.blocks[i], st.blocks[i])
st_blocks = (st_blocks..., st_block_new)
end
x, st_norm = qwen3.norm(x, ps.norm, st.norm)
logits, st_lm_head = qwen3.lm_head(
fn_to_dtype(qwen3.cfg.dtype).(x), ps.lm_head, st.lm_head
)
return (
logits,
(;
cos_cache=st.cos_cache,
sin_cache=st.sin_cache,
embed_tokens=st_embed_tokens,
blocks=st_blocks,
norm=st_norm,
lm_head=st_lm_head,
),
)
end
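Before loading pretrained weights, an end-to-end shape check with randomly initialized parameters is a useful smoke test (illustrative only; `dtype=f32` keeps the random parameters in `Float32`, and running the full network eagerly on a CPU is slow):
julia
model = Qwen3(Qwen3Config("0.6B"; dtype=f32))
ps, st = Lux.setup(Random.default_rng(), model)
tokens = rand(1:151_936, 5, 1)           # (num_tokens, batch) of vocabulary ids
logits, _ = model(tokens, ps, st)
@assert size(logits) == (151_936, 5, 1)  # per-token logits over the vocabulary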
Model Weights and Tokenizer from HuggingFace
julia
function download_qwen3_weights_from_huggingface(cfg::Qwen3Config)
return download_qwen3_weights_from_huggingface(cfg.reasoning_model, cfg.version)
end
function download_qwen3_weights_from_huggingface(use_reasoning_model::Bool, version::String)
repo_id = "Qwen/Qwen3-$(version)" * (use_reasoning_model ? "" : "-Base")
local_dir = @get_scratch!("Qwen3-$(version)-$(use_reasoning_model)")
tokenizer_file = huggingface_hub.hf_hub_download(;
repo_id=repo_id, filename="tokenizer.json", local_dir=local_dir
)
if version == "0.6B"
weights_file = huggingface_hub.hf_hub_download(;
repo_id=repo_id, filename="model.safetensors", local_dir=local_dir
)
weights_dict = load_safetensors(string(weights_file))
else
repo_dir = huggingface_hub.snapshot_download(; repo_id=repo_id, local_dir=local_dir)
index_path = joinpath(string(repo_dir), "model.safetensors.index.json")
index = JSON3.read(index_path)
weights_dict = Dict()
for filename in Set(values(index["weight_map"]))
shard_path = joinpath(string(repo_dir), filename)
shard = load_safetensors(shard_path)
merge!(weights_dict, shard)
end
end
return weights_dict, string(tokenizer_file), repo_id
end
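For example (a usage sketch; the first call downloads the checkpoint, on the order of a gigabyte for the 0.6B model, into a scratch space that is reused on subsequent calls):
julia
cfg = Qwen3Config("0.6B")  # reasoning_model=true by default
weights_dict, tokenizer_file, repo_id = download_qwen3_weights_from_huggingface(cfg)
@assert repo_id == "Qwen/Qwen3-0.6B"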
Qwen3 Tokenizer
julia
struct Qwen3Tokenizer
tokenizer::Tokenizer
special_to_id::Dict{String,Int32}
pad_token_id::Int32
eos_token_id::Int32
apply_chat_template::Bool
add_generation_prompt::Bool
add_thinking::Bool
end
function Base.show(io::IO, tokenizer::Qwen3Tokenizer)
return print(
io,
"Qwen3Tokenizer(apply_chat_template=$(tokenizer.apply_chat_template), add_generation_prompt=$(tokenizer.add_generation_prompt), add_thinking=$(tokenizer.add_thinking))",
)
end
const SPECIALS = [
"<|endoftext|>",
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>",
]
const SPLIT_RE = r"(<\|[^>]+?\|>)"
token_to_id(tokenizer::Qwen3Tokenizer, s) = token_to_id(tokenizer.tokenizer, s)
function token_to_id(tokenizer::Tokenizer, s)
return pyconvert(Int32, tokenizer.py_tokenizer.token_to_id(s)) + Int32(1)
end
function split_with_delims(text::String, re::Regex)
parts = String[]
last_end = 1
for m in eachmatch(re, text)
if m.offset > last_end
push!(parts, text[last_end:(m.offset - 1)])
elseif m.offset == 1
push!(parts, "")
end
push!(parts, m.match)
last_end = m.offset + length(m.match)
end
if last_end ≤ lastindex(text)
push!(parts, text[last_end:end])
end
return parts
end
function Qwen3Tokenizer(
tokenizer_file_path::String;
repo_id=nothing,
apply_chat_template::Bool=true,
add_generation_prompt::Bool=false,
add_thinking::Bool=false,
)
tok = HuggingFaceTokenizers.from_file(Tokenizer, tokenizer_file_path)
special_to_id = Dict(s => token_to_id(tok, s) for s in SPECIALS)
pad_token_id = special_to_id["<|endoftext|>"]
eos_token_id = pad_token_id
if repo_id !== nothing && !occursin("Base", repo_id)
eos_token = "<|im_end|>"
else
eos_token = "<|endoftext|>"
end
if haskey(special_to_id, eos_token)
eos_token_id = special_to_id[eos_token]
end
return Qwen3Tokenizer(
tok,
special_to_id,
pad_token_id,
eos_token_id,
apply_chat_template,
add_generation_prompt,
add_thinking,
)
end
function wrap_chat(tokenizer::Qwen3Tokenizer, user_msg::AbstractString)
s = "<|im_start|>user\n$(user_msg)<|im_end|>\n"
if tokenizer.add_generation_prompt
s *= "<|im_start|>assistant"
if tokenizer.add_thinking
s *= "\n"
else
s *= "\n<think>\n\n</think>\n\n"
end
end
return s
end
function HuggingFaceTokenizers.encode(
tok::Qwen3Tokenizer, text; chat_wrapped::Bool=tok.apply_chat_template
)
stripped = strip(text)
if haskey(tok.special_to_id, stripped) && !occursin('\n', stripped)
return [tok.special_to_id[stripped]]
end
chat_wrapped && (text = wrap_chat(tok, text))
ids = Int32[]
for part in filter(!isempty, split_with_delims(text, SPLIT_RE))
if haskey(tok.special_to_id, part)
push!(ids, tok.special_to_id[part])
else
append!(ids, encode(tok.tokenizer, string(part)).ids .+ Int16(1))
end
end
return ids
end
function HuggingFaceTokenizers.decode(tok::Qwen3Tokenizer, ids::Vector{<:Integer})
return decode(tok.tokenizer, ids .- Int16(1); skip_special_tokens=false)
end
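A round trip illustrates the chat template and the index shift: all ids are stored off-by-one relative to HuggingFace so that they can be used directly as 1-based Julia array indices. This sketch assumes the `tokenizer_file` and `repo_id` returned by the download helper above:
julia
tokenizer = Qwen3Tokenizer(tokenizer_file; repo_id, add_generation_prompt=true)
ids = encode(tokenizer, "Hello, world!")
print(decode(tokenizer, ids))  # "<|im_start|>user\nHello, world!<|im_end|>\n..."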
Pretrained Model Weights
julia
get_weights_tensor(tensor::AbstractArray, ::Type{T}) where {T} = collect(T, tensor)
function get_weights_tensor(dict, key, dtype::Type{T}, dev; permute::Bool=false) where {T}
tensor = dict[key]
if permute
tensor = permutedims(tensor, Tuple(reverse(1:ndims(tensor))))
end
return get_weights_tensor(tensor, dtype) |> dev
end
function load_weights_from_dict(weights_dict, cfg::Qwen3Config, dev)
dtype = fn_to_dtype(cfg.dtype)
function get_tensor(key; kwargs...)
return get_weights_tensor(weights_dict, key, dtype, dev; kwargs...)
end
embed_tokens = (; weight=get_tensor("model.embed_tokens.weight"; permute=true))
blocks = Vector{Any}(undef, cfg.n_layers)
for l in 1:(cfg.n_layers)
prefix = "model.layers.$(l - 1)"
sa_prefix = "$(prefix).self_attn"
blocks[l] = (;
self_attn=merge(
NamedTuple(
k => (; weight=get_tensor("$(sa_prefix).$k.weight")) for
k in (:q_proj, :k_proj, :v_proj, :o_proj)
),
NamedTuple(
k => (; scale=get_tensor("$(sa_prefix).$k.weight")) for
k in (:q_norm, :k_norm)
),
),
mlp=(;
proj=(;
gate_proj=(; weight=get_tensor("$(prefix).mlp.gate_proj.weight")),
up_proj=(; weight=get_tensor("$(prefix).mlp.up_proj.weight")),
),
down_proj=(; weight=get_tensor("$(prefix).mlp.down_proj.weight")),
),
input_layernorm=(; scale=get_tensor("$(prefix).input_layernorm.weight")),
post_attention_layernorm=(;
scale=get_tensor("$(prefix).post_attention_layernorm.weight")
),
)
end
blocks = Tuple(blocks)
norm = (; scale=get_weights_tensor(weights_dict, "model.norm.weight", dtype, dev))
if haskey(weights_dict, "lm_head.weight")
lm_head = (; weight=get_weights_tensor(weights_dict, "lm_head.weight", dtype, dev))
else
# Weight tying with the embedding matrix. We will share the weights here to
# reduce memory usage.
lm_head = (; weight=transpose(embed_tokens.weight))
end
return (; embed_tokens, blocks, norm, lm_head)
end
function setup_model(
version::String, dev; weights_dict::Union{Nothing,Dict}=nothing, kwargs...
)
return setup_model(Qwen3Config(version; kwargs...), dev; weights_dict)
end
function setup_model(cfg::Qwen3Config, dev; weights_dict::Union{Nothing,Dict}=nothing)
model = Qwen3(cfg)
st = Lux.initialstates(Random.default_rng(), model) |> dev
if weights_dict !== nothing
ps = load_weights_from_dict(weights_dict, cfg, dev)
else
ps = Lux.initialparameters(Random.default_rng(), model) |> dev
ps = ps |> cfg.dtype
end
return model, ps, st
end
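Putting the pieces together (a sketch assuming a Reactant-capable device; the entry point below wraps the same steps in an interactive prompt):
julia
cfg = Qwen3Config("0.6B")
weights_dict, _, _ = download_qwen3_weights_from_huggingface(cfg)
model, ps, st = setup_model(cfg, reactant_device(); weights_dict)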
Running the Model without Dynamic Sizes
julia
function get_padded_size(seq_len::Int, context_length::Int)
return min(max(512, nextpow(2, seq_len)), context_length)
end
function padded_input_and_mask_len(x::AbstractMatrix, v, cfg::Qwen3Config, pad_token_id)
# reserve one extra slot when a pending next token `v` has to be appended
return padded_input_and_mask_len(
x, v, get_padded_size(size(x, 1) + (v !== nothing), cfg.context_length), pad_token_id
)
end
function padded_input_and_mask_len(x::AbstractMatrix, v, padded_sz::Int, pad_token_id)
if padded_sz > size(x, 1)
x_padded = similar(x, (padded_sz, size(x, 2)))
x_padded[1:size(x, 1), :] .= x
if v === nothing
x_padded[(size(x, 1) + 1):end, :] .= pad_token_id
else
x_padded[(size(x, 1) + 1), :] = v[1, :]
x_padded[(size(x, 1) + 2):end, :] .= pad_token_id
end
else
x_padded = x
end
return (
x_padded,
Reactant.TracedUtils.promote_to(
Reactant.TracedRNumber{Int32}, padded_sz - (size(x, 1) + (v !== nothing))
),
)
end
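The bucketing above rounds the sequence length up to the next power of two, with a floor of 512 and a cap at the context length, so the compiled functions below are reused until generation crosses into the next bucket:
julia
@assert get_padded_size(100, 40_960) == 512       # floor of 512 tokens
@assert get_padded_size(600, 40_960) == 1024      # next power of two
@assert get_padded_size(50_000, 40_960) == 40_960 # capped at the context length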
Helpers to Generate Text
julia
function predict_next_token(
model, token_ids::AbstractMatrix{T}, input_mask_len, ps, st
) where {T}
logits, stₙ = model(token_ids, ps, st)
# greedy decoding: take the argmax over the vocabulary at the last real
# (non-padding) token position. `argmax(...; dims=1)` returns linear indices
# into the (vocab, batch) slice, so `mod1` recovers the vocabulary row index.
predictions = T.(argmax(logits[:, end - input_mask_len, :]; dims=1))
predictions = mod1.(predictions, T(size(logits, 1)))
return predictions, stₙ
end
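# Two in-place buffer updates follow: while the padded buffer still has free
# slots, the new token is written at the next index and the padding mask
# shrinks by one (`update_token_ids_and_mask!`); once the context window is
# full, all tokens are instead shifted left by one position
# (`update_token_ids_with_shift!`).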
function update_token_ids_and_mask!(
padded_token_ids, input_mask_len, cur_num_tokens, next_token
)
next_token_idx = safe_increment(cur_num_tokens)
padded_token_ids[next_token_idx, :] = next_token[1, :]
return input_mask_len - eltype(input_mask_len)(1), next_token_idx
end
function update_token_ids_with_shift!(token_ids, next_token)
token_ids[1:(end - 1), :] = token_ids[2:end, :]
token_ids[end, :] = next_token[1, :]
return nothing
end
safe_increment(x) = x + one(x)
mutable struct CachedReactantThunks
cache::Dict{Qwen3Config,Dict{Int,NTuple{3,Reactant.Compiler.Thunk}}}
increment_fn::Union{Nothing,Reactant.Compiler.Thunk}
end
function CachedReactantThunks()
return CachedReactantThunks(
Dict{Qwen3Config,Dict{Int,NTuple{3,Reactant.Compiler.Thunk}}}(), nothing
)
end
function cache_and_retrieve!(
cache::CachedReactantThunks,
len::Integer,
model::Qwen3,
padded_token_ids,
input_mask_len,
ps,
st,
next_token,
cur_num_tokens_traced,
)
if haskey(cache.cache, model.cfg) && haskey(cache.cache[model.cfg], len)
return cache.cache[model.cfg][len]
end
println()
@warn "Compiling Qwen3 generation loop for $(model.cfg.version) with $(len) tokens. \
This might take a while... (However this is only done once per model per length)"
predict_next_token_compiled = @compile predict_next_token(
model, padded_token_ids, input_mask_len, ps, st
)
update_fn1! = @compile update_token_ids_and_mask!(
padded_token_ids, input_mask_len, cur_num_tokens_traced, next_token
)
update_fn2! = @compile update_token_ids_with_shift!(padded_token_ids, next_token)
if !haskey(cache.cache, model.cfg)
cache.cache[model.cfg] = Dict{Int,NTuple{3,Reactant.Compiler.Thunk}}()
end
return cache.cache[model.cfg][len] = (
predict_next_token_compiled, update_fn1!, update_fn2!
)
end
const CACHED_THUNKS = CachedReactantThunks()
generate_text(args...; kwargs...) = generate_text!(CACHED_THUNKS, args...; kwargs...)
function generate_text!(
compile_cache::CachedReactantThunks,
model::Qwen3,
prompt::String,
ps,
st,
max_new_tokens,
tokenizer,
)
token_ids = Reactant.to_rarray(reshape(encode(tokenizer, prompt), :, 1))
# TODO: compile the generation loop with Reactant
# TODO: implement some simple KV caching
cur_num_tokens = size(token_ids, 1)
max_context_length = model.cfg.context_length
cur_compiled_fn_token_len = get_padded_size(cur_num_tokens, max_context_length)
padded_token_ids, input_mask_len = @jit padded_input_and_mask_len(
token_ids, nothing, cur_compiled_fn_token_len, tokenizer.pad_token_id
)
cur_num_tokens_traced = ConcreteRNumber{Int32}(cur_num_tokens)
next_token = get_device(ps)(rand(Int32, 1, size(padded_token_ids, 2)))
(predict_next_token_compiled, update_fn1!, update_fn2!) = cache_and_retrieve!(
compile_cache,
cur_compiled_fn_token_len,
model,
padded_token_ids,
input_mask_len,
ps,
st,
next_token,
cur_num_tokens_traced,
)
if compile_cache.increment_fn === nothing
compile_cache.increment_fn = @compile safe_increment(cur_num_tokens_traced)
end
start_time = time()
compile_time = 0.0
ntokens_generated = 0
for _ in 1:max_new_tokens
new_compiled_fn_token_len = get_padded_size(cur_num_tokens, max_context_length)
if new_compiled_fn_token_len != cur_compiled_fn_token_len
compile_start_time = time()
cur_compiled_fn_token_len = new_compiled_fn_token_len
padded_token_ids, input_mask_len = @jit padded_input_and_mask_len(
padded_token_ids,
next_token,
cur_compiled_fn_token_len,
tokenizer.pad_token_id,
)
(predict_next_token_compiled, update_fn1!, update_fn2!) = cache_and_retrieve!(
compile_cache,
cur_compiled_fn_token_len,
model,
padded_token_ids,
input_mask_len,
ps,
st,
next_token,
cur_num_tokens_traced,
)
compile_time += time() - compile_start_time
end
next_token, st = predict_next_token_compiled(
model, padded_token_ids, input_mask_len, ps, st
)
ntokens_generated += 1
next_token_jl = vec(Array(next_token))
if tokenizer.eos_token_id !== nothing &&
all(next_token_jl .== tokenizer.eos_token_id)
break
end
print(decode(tokenizer, next_token_jl))
if cur_num_tokens >= max_context_length
update_fn2!(padded_token_ids, next_token)
elseif new_compiled_fn_token_len > cur_num_tokens
input_mask_len, cur_num_tokens_traced = update_fn1!(
padded_token_ids, input_mask_len, cur_num_tokens_traced, next_token
)
else
cur_num_tokens_traced = compile_cache.increment_fn(cur_num_tokens_traced)
end
cur_num_tokens += 1
end
total_time = time() - start_time
println()
# exclude one-off compilation time from the reported tokens/second
return ntokens_generated / (total_time - compile_time)
end
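A single call then drives the whole loop and returns the measured throughput (a sketch reusing the model, parameters, state, and tokenizer constructed above; the first call for each padded length triggers compilation):
julia
tps = generate_text(model, "Give me a short introduction to Julia.", ps, st, 256, tokenizer)
println("Throughput: $(round(tps; digits=2)) tokens/s")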
Entry Point
julia
function run_model_selection()
printstyled("Which model do you want to run? \n"; color=:cyan, bold=true)
choices = ["0.6B", "1.7B", "4B", "8B", "14B", "32B"]
for (i, choice) in enumerate(choices)
printstyled(" $(i). $(choice)\n"; color=:light_blue)
end
printstyled(" Enter your choice: "; color=:cyan)
choice = parse(Int, readline(stdin))
if choice ∉ 1:length(choices)
error("Invalid choice: $(choice). Expected an integer between 1 and \
$(length(choices))")
end
printstyled("Do you want to use the reasoning model? [y/N] "; color=:cyan)
reasoning = readline(stdin) == "y"
println()
return choices[choice], reasoning
end
function get_model_and_tokenizer(version, reasoning)
cfg = Qwen3Config(version; reasoning_model=reasoning)
rdev = reactant_device(; force=true)
weights_dict, tokenizer_file, repo_id = download_qwen3_weights_from_huggingface(cfg)
tokenizer = Qwen3Tokenizer(
tokenizer_file;
repo_id,
add_generation_prompt=cfg.reasoning_model,
add_thinking=cfg.reasoning_model,
)
model, ps, st = setup_model(cfg, rdev; weights_dict)
return model, ps, st, tokenizer
end
function main()
@info "Text Generation with Qwen-3 powered by Lux, Reactant & XLA."
version, reasoning = run_model_selection()
model, ps, st, tokenizer = get_model_and_tokenizer(version, reasoning)
while true
printstyled(
"Prompt (type \"exit\" to quit the program or \
\"model selection\" to change the model): ";
color=:cyan,
bold=true,
)
prompt = readline(stdin)
prompt == "exit" && break
if prompt == "model selection"
version, reasoning = run_model_selection()
model, ps, st, tokenizer = get_model_and_tokenizer(version, reasoning)
continue
end
tokens_per_second = generate_text(model, prompt, ps, st, 100_000, tokenizer)
println("\nTokens per second: $tokens_per_second\n\n")
end
return nothing
end
if abspath(PROGRAM_FILE) == @__FILE__
main()
end
This page was generated using Literate.jl.