Not Run on CI
This tutorial is not run on CI to reduce the computational burden. If you encounter any issues, please open an issue on the Lux.jl repository.
Denoising Diffusion Implicit Model (DDIM)
Lux.jl implementation of Denoising Diffusion Implicit Models (arXiv:2010.02502). The model generates images from Gaussian noise by iteratively *denoising* it.
Package Imports
julia
# Default the `XLA_REACTANT_GPU_MEM_FRACTION` environment variable to "0.98" while
# keeping any value that is already set by the user.
# NOTE(review): presumably this has to run before Reactant is loaded — confirm.
ENV["XLA_REACTANT_GPU_MEM_FRACTION"] = get(ENV, "XLA_REACTANT_GPU_MEM_FRACTION", "0.98")
using ConcreteStructs,
ArgParse,
DataAugmentation,
DataDeps,
Dates,
Enzyme,
FileIO,
ImageCore,
ImageShow,
JLD2,
Lux,
MLUtils,
Optimisers,
ParameterSchedulers,
ProgressTables,
Printf,
Random,
Reactant,
Statistics,
TensorBoardLogger,
OhMyThreads
const IN_VSCODE = isdefined(Main, :VSCodeServer)
Model Definition
This DDIM implementation follows the Keras example. First, we map the noise variances to a sinusoidal embedding.
julia
"""
    sinusoidal_embedding(x, min_freq, max_freq, embedding_dims)

Map every scalar in `x` (shape `(1, 1, 1, batch)`) to sine/cosine features.

`embedding_dims ÷ 2` frequencies are spaced evenly in log space between `min_freq` and
`max_freq`. The sines and cosines of `2π * x * freq` are concatenated along dimension 3,
so the result has size `(1, 1, 2 * (embedding_dims ÷ 2), batch)`.

Throws a `DimensionMismatch` when the first three dimensions of `x` are not all 1.
"""
function sinusoidal_embedding(
    x::AbstractArray{T,4}, min_freq, max_freq, embedding_dims::Int
) where {T}
    size(x)[1:3] == (1, 1, 1) ||
        throw(DimensionMismatch("Input shape must be (1, 1, 1, batch)"))

    num_freqs = embedding_dims ÷ 2
    log_lo = T(log(min_freq))
    log_hi = T(log(max_freq))
    # Even spacing of the logarithms gives geometric spacing of the frequencies.
    log_freqs = range(log_lo, log_hi; length=num_freqs)
    freqs = exp.(reshape(log_freqs, 1, 1, num_freqs, 1))

    # `sinpi`/`cospi` already include the factor π, hence only the factor 2 here.
    phase = @. 2 * x * freqs
    return cat(sinpi.(phase), cospi.(phase); dims=Val(3))
end
# Wrapper layer (`AbstractLuxWrapperLayer{:layer}`) around the network stored in `layer`.
@concrete struct ResidualBlock <: AbstractLuxWrapperLayer{:layer}
    layer
end

"""
    ResidualBlock(in_channels, out_channels)

Build a residual block that adds two branches with `Parallel(+, ...)`:

  - a shortcut: `NoOpLayer()` when the channel counts agree, otherwise a 1×1 `Conv`
    mapping `in_channels` to `out_channels`;
  - a main path: `BatchNorm` with `affine=false`, a 3×3 `Conv` with `swish`, and a
    second 3×3 `Conv`.
"""
function ResidualBlock(in_channels::Int, out_channels::Int)
    shortcut = if in_channels == out_channels
        NoOpLayer()
    else
        Conv((1, 1), in_channels => out_channels; pad=SamePad())
    end
    main_path = Chain(
        BatchNorm(in_channels; affine=false),
        Conv((3, 3), in_channels => out_channels, swish; pad=SamePad()),
        Conv((3, 3), out_channels => out_channels; pad=SamePad()),
    )
    return ResidualBlock(Parallel(+, shortcut, main_path))
end
# Container layer holding a tuple of residual blocks (`residual_blocks`) and a pooling
# layer (`pool`).
@concrete struct DownsampleBlock <: AbstractLuxContainerLayer{(:residual_blocks, :pool)}
    residual_blocks
    pool
end

"""
    DownsampleBlock(in_channels, out_channels, block_depth)

Create `block_depth` residual blocks followed by 2×2 mean pooling. The first residual
block maps `in_channels` to `out_channels`; the remaining ones keep `out_channels`.
"""
function DownsampleBlock(in_channels::Int, out_channels::Int, block_depth::Int)
    blocks = Tuple(
        ResidualBlock(i == 1 ? in_channels : out_channels, out_channels) for
        i in 1:block_depth
    )
    return DownsampleBlock(blocks, MeanPool((2, 2)))
end
"""
    (d::DownsampleBlock)(x, ps, st)

Apply the residual blocks in sequence, then the pooling layer.

Returns `((pooled, skips), new_state)`, where `skips` is a tuple holding the input `x`
followed by the output of every residual block, and `new_state` has the fields
`residual_blocks` and `pool`.
"""
function (d::DownsampleBlock)(x::AbstractArray, ps, st::NamedTuple)
    current, skips, block_states = x, (x,), ()
    for idx in eachindex(d.residual_blocks)
        current, new_state = d.residual_blocks[idx](
            current, ps.residual_blocks[idx], st.residual_blocks[idx]
        )
        # Every intermediate activation is kept as a skip connection.
        skips = (skips..., current)
        block_states = (block_states..., new_state)
    end
    pooled, pool_state = d.pool(current, ps.pool, st.pool)
    return (pooled, skips), (; residual_blocks=block_states, pool=pool_state)
end
# Container layer holding a tuple of residual blocks (`residual_blocks`) and an
# upsampling layer (`upsample`).
@concrete struct UpsampleBlock <: AbstractLuxContainerLayer{(:residual_blocks, :upsample)}
    residual_blocks
    upsample
end

"""
    UpsampleBlock(in_channels, out_channels, block_depth)

Create a bilinear 2× upsampling layer and `block_depth` residual blocks. The first
residual block takes `in_channels + out_channels` input channels and every later one
takes `2 * out_channels`; all of them produce `out_channels`.
"""
function UpsampleBlock(in_channels::Int, out_channels::Int, block_depth::Int)
    blocks = Tuple(
        ResidualBlock(
            i == 1 ? in_channels + out_channels : out_channels * 2, out_channels
        ) for i in 1:block_depth
    )
    return UpsampleBlock(blocks, Upsample(:bilinear; scale=2))
end
"""
    (u::UpsampleBlock)((x, skips), ps, st)

Upsample `x`, then run the residual blocks. Before each block, the running activation
is concatenated along dimension 3 with one entry of `skips`, taken from the end of the
tuple backwards.

Returns the final activation and a state with the fields `residual_blocks` and
`upsample`.
"""
function (u::UpsampleBlock)((x, skips), ps, st::NamedTuple)
    upsampled, upsample_state = u.upsample(x, ps.upsample, st.upsample)
    features, block_states = upsampled, ()
    for idx in eachindex(u.residual_blocks)
        # Block `idx` consumes the `idx`-th skip counted from the end.
        skip = skips[end - idx + 1]
        features, new_state = u.residual_blocks[idx](
            cat(features, skip; dims=Val(3)),
            ps.residual_blocks[idx],
            st.residual_blocks[idx],
        )
        block_states = (block_states..., new_state)
    end
    return features, (; residual_blocks=block_states, upsample=upsample_state)
end
# U-Net used as the denoiser. The names listed in the `AbstractLuxContainerLayer` type
# parameter are the sub-layer fields; `min_freq`, `max_freq` and `embedding_dims` are
# plain settings that the forward pass hands to `sinusoidal_embedding`.
@concrete struct UNet <: AbstractLuxContainerLayer{(
:conv_in, :conv_out, :down_blocks, :residual_blocks, :up_blocks, :upsample
)}
# Resizes the noise embedding to the image size.
upsample
# 1×1 convolutions applied to the input images / producing the 3-channel output.
conv_in
conv_out
# Tuple of `DownsampleBlock`s, a `Chain` of `ResidualBlock`s, tuple of `UpsampleBlock`s.
down_blocks
residual_blocks
up_blocks
# Settings for the sinusoidal embedding of the noise variances.
min_freq
max_freq
embedding_dims
end
"""
    UNet(image_size; channels, block_depth, min_freq, max_freq, embedding_dims)

Assemble the U-Net denoiser for images of spatial size `image_size`.

# Keywords
  - `channels`: channel count per stage. All but the last entry get a
    `DownsampleBlock`; the last entry is the width of the middle residual blocks.
  - `block_depth`: number of residual blocks in every down/up/middle stage.
  - `min_freq`, `max_freq`, `embedding_dims`: stored on the layer and passed to
    `sinusoidal_embedding` in the forward pass.

`channels` is only read; the caller's collection is left unmodified.
"""
function UNet(
    image_size::Dims{2};
    channels=[32, 64, 96, 128],
    block_depth=2,
    min_freq=1.0f0,
    max_freq=1000.0f0,
    embedding_dims=32,
)
    upsample = Upsample(:nearest; size=image_size)
    conv_in = Conv((1, 1), 3 => channels[1])
    # The output projection starts from all-zero weights.
    conv_out = Conv((1, 1), channels[1] => 3; init_weight=zeros32)

    # The first stage sees the projected image concatenated with the noise embedding.
    channel_input = embedding_dims + channels[1]
    down_blocks = Tuple([
        DownsampleBlock(i == 1 ? channel_input : channels[i - 1], channels[i], block_depth)
        for i in 1:(length(channels) - 1)
    ])

    residual_blocks = Chain(
        [
            ResidualBlock(ifelse(i == 1, channels[end - 1], channels[end]), channels[end])
            for i in 1:block_depth
        ]...,
    )

    # Use a reversed copy: an in-place `reverse!` would silently reorder the vector
    # owned by the caller.
    up_channels = reverse(channels)
    up_blocks = Tuple([
        UpsampleBlock(in_chs, out_chs, block_depth) for
        (in_chs, out_chs) in zip(up_channels[1:(end - 1)], up_channels[2:end])
    ])

    return UNet(
        upsample,
        conv_in,
        conv_out,
        down_blocks,
        residual_blocks,
        up_blocks,
        min_freq,
        max_freq,
        embedding_dims,
    )
end
"""
    (u::UNet)((noisy_images, noise_variances), ps, st)

Forward pass of the U-Net. `noise_variances` must have size `(1, 1, 1, batch)` with the
same batch size as `noisy_images`.

The noise variances are embedded with `sinusoidal_embedding`, resized by `upsample` and
concatenated (dimension 3) with the output of `conv_in`. The result goes through the
down blocks, the middle residual blocks and the up blocks (which receive the skips of
the down blocks in reverse order), and finally through `conv_out`.

Returns the output array and the updated state of all sub-layers.
"""
function (u::UNet)((noisy_images, noise_variances), ps, st::NamedTuple)
    @assert size(noise_variances)[1:3] == (1, 1, 1)
    @assert size(noisy_images, 4) == size(noise_variances, 4)

    embedding = sinusoidal_embedding(
        noise_variances, u.min_freq, u.max_freq, u.embedding_dims
    )
    emb, st_upsample = u.upsample(embedding, ps.upsample, st.upsample)
    projected, st_conv_in = u.conv_in(noisy_images, ps.conv_in, st.conv_in)
    x = cat(projected, emb; dims=Val(3))

    # Downsampling path: remember the skips produced by every stage.
    all_skips, down_states = (), ()
    for idx in eachindex(u.down_blocks)
        (x, stage_skips), new_state = u.down_blocks[idx](
            x, ps.down_blocks[idx], st.down_blocks[idx]
        )
        all_skips = (all_skips..., stage_skips)
        down_states = (down_states..., new_state)
    end

    x, mid_state = u.residual_blocks(x, ps.residual_blocks, st.residual_blocks)

    # Upsampling path: stages consume the stored skips from last to first.
    up_states = ()
    for idx in eachindex(u.up_blocks)
        x, new_state = u.up_blocks[idx](
            (x, all_skips[end - idx + 1]), ps.up_blocks[idx], st.up_blocks[idx]
        )
        up_states = (up_states..., new_state)
    end

    x, st_conv_out = u.conv_out(x, ps.conv_out, st.conv_out)

    new_st = (;
        conv_in=st_conv_in,
        conv_out=st_conv_out,
        down_blocks=down_states,
        residual_blocks=mid_state,
        up_blocks=up_states,
        upsample=st_upsample,
    )
    return x, new_st
end
"""
    diffusion_schedules(diffusion_times, min_signal_rate, max_signal_rate)

Turn diffusion times into `(noise_rates, signal_rates)`.

Each time is mapped linearly to an angle between `acos(max_signal_rate)` (time 0) and
`acos(min_signal_rate)` (time 1). The signal rate is the cosine and the noise rate the
sine of that angle, so their squares sum to one.
"""
function diffusion_schedules(
    diffusion_times::AbstractArray{T,4}, min_signal_rate, max_signal_rate
) where {T}
    angle_lo = T(acos(max_signal_rate))
    angle_hi = T(acos(min_signal_rate))
    angles = angle_lo .+ (angle_hi - angle_lo) .* diffusion_times
    return sin.(angles), cos.(angles)
end
"""
    denoise(unet, noisy_images, noise_rates, signal_rates)

Predict the noise contained in `noisy_images` and reconstruct the clean images from it.

All three arrays are first converted to their common promoted element type, so that
`unet` and the reconstruction formula see a single element type (the same convention as
`denoise!`). `unet` is called with the tuple `(noisy_images, noise_rates .^ 2)`.

Returns `(pred_noises, pred_images)` with
`pred_images = (noisy_images - pred_noises * noise_rates) / signal_rates`.
"""
function denoise(
    unet,
    noisy_images::AbstractArray{T1,4},
    noise_rates::AbstractArray{T2,4},
    signal_rates::AbstractArray{T3,4},
) where {T1,T2,T3}
    T = promote_type(T1, T2, T3)
    noisy_images = T.(noisy_images)
    # Promote the noise rates as well; otherwise the network receives the noise
    # variances in a different element type than the images.
    noise_rates = T.(noise_rates)
    signal_rates = T.(signal_rates)
    pred_noises = unet((noisy_images, noise_rates .^ 2))
    pred_images = @. (noisy_images - pred_noises * noise_rates) / signal_rates
    return pred_noises, pred_images
end
"""
    denoise!(pred_images, unet, noisy_images, noise_rates, signal_rates)

In-place variant of `denoise`: the reconstructed images are written into `pred_images`
and only the predicted noise is returned.

The three input arrays are converted to their common promoted element type before
`unet` is called with `(noisy_images, noise_rates .^ 2)`.
"""
function denoise!(
    pred_images,
    unet,
    noisy_images::AbstractArray{T1,4},
    noise_rates::AbstractArray{T2,4},
    signal_rates::AbstractArray{T3,4},
) where {T1,T2,T3}
    T = promote_type(T1, T2, T3)
    noisy, nrates, srates = T.(noisy_images), T.(noise_rates), T.(signal_rates)
    pred_noises = unet((noisy, nrates .^ 2))
    pred_images .= (noisy .- pred_noises .* nrates) ./ srates
    return pred_noises
end
# Diffusion model: a U-Net denoiser (`unet`) plus a batch-norm layer (`bn`) applied to
# the input images, together with the signal-rate bounds of the diffusion schedule and
# the full image size including the channel dimension.
@concrete struct DDIM <: AbstractLuxContainerLayer{(:unet, :bn)}
    unet
    bn
    min_signal_rate
    max_signal_rate
    image_size::Dims{3}
end

"""
    DDIM(image_size, args...; min_signal_rate=0.02f0, max_signal_rate=0.95f0, kwargs...)

Construct a `DDIM` for 3-channel images of spatial size `image_size`. Extra positional
and keyword arguments are forwarded to `UNet`. The batch-norm layer has no affine
parameters and tracks running statistics.
"""
function DDIM(
    image_size::Dims{2}, args...; min_signal_rate=0.02f0, max_signal_rate=0.95f0, kwargs...
)
    unet = UNet(image_size, args...; kwargs...)
    bn = BatchNorm(3; affine=false, track_stats=true)
    full_size = (image_size..., 3)
    return DDIM(unet, bn, min_signal_rate, max_signal_rate, full_size)
end
"""
    Lux.initialstates(rng, ddim::DDIM)

Initial state of a `DDIM`: the states of `bn` and `unet` plus the `rng` itself, which
the forward pass and `generate` later use for sampling.
"""
function Lux.initialstates(rng::AbstractRNG, ddim::DDIM)
    # Draw (and discard) one number so the stored RNG is advanced past its input state.
    rand(rng, 1)
    bn_state = Lux.initialstates(rng, ddim.bn)
    unet_state = Lux.initialstates(rng, ddim.unet)
    return (; rng, bn=bn_state, unet=unet_state)
end
"""
    (ddim::DDIM)(x, ps, st)

Training-time forward pass. The batch `x` goes through `bn`, is mixed with Gaussian
noise at a random diffusion time per sample, and the U-Net is asked to separate the two
again.

Returns `((noises, images, pred_noises, pred_images), new_state)`; `new_state` carries
the advanced RNG and the updated `bn`/`unet` states.
"""
function (ddim::DDIM)(x::AbstractArray{T,4}, ps, st::NamedTuple) where {T}
    images, st_bn = ddim.bn(x, ps.bn, st.bn)

    # Work on a replica so the RNG stored in `st` is not mutated.
    rng = Lux.replicate(st.rng)
    batch = size(images, 4)
    diffusion_times = rand_like(rng, images, (1, 1, 1, batch))
    noise_rates, signal_rates = diffusion_schedules(
        diffusion_times, ddim.min_signal_rate, ddim.max_signal_rate
    )

    noises = randn_like(rng, images)
    noisy_images = signal_rates .* images .+ noise_rates .* noises

    unet = StatefulLuxLayer{true}(ddim.unet, ps.unet, st.unet)
    pred_noises, pred_images = denoise(unet, noisy_images, noise_rates, signal_rates)

    outputs = (noises, images, pred_noises, pred_images)
    return outputs, (; rng, bn=st_bn, unet=unet.st)
end
# Helper Functions for Image Generation
"""
    generate(model, ps, st, diffusion_steps, num_samples)

Sample `num_samples` images: draw Gaussian noise of the model's image size, run
`reverse_diffusion` for `diffusion_steps` steps, undo the batch-norm normalisation with
the running statistics stored in `st.bn`, and clamp the result with `clamp01`.
"""
function generate(model::DDIM, ps, st::NamedTuple, diffusion_steps::Int, num_samples::Int)
    rng = Lux.replicate(st.rng)
    running_mean = st.bn.running_mean
    running_var = st.bn.running_var
    noise_size = (model.image_size..., num_samples)
    initial_noise = randn_like(rng, running_mean, noise_size)
    samples = reverse_diffusion(model, initial_noise, ps, st, diffusion_steps)
    return clamp01.(denormalize(samples, running_mean, running_var, model.bn.epsilon))
end
"""
    reverse_diffusion_single_step!(pred_images, noisy_images, step, step_size, unet,
                                   ps, st, ones, min_signal_rate, max_signal_rate)

One reverse-diffusion step, done in place.

The current diffusion time is `ones .- step_size * step`. `denoise!` writes the image
estimate for that time into `pred_images`; `noisy_images` is then overwritten with the
estimate re-noised at the next time (`step_size` earlier). Returns `nothing`.
"""
function reverse_diffusion_single_step!(
    pred_images,
    noisy_images,
    step,
    step_size,
    unet,
    ps,
    st,
    ones,
    min_signal_rate,
    max_signal_rate,
)
    current_times = ones .- step_size * step
    noise_rates, signal_rates = diffusion_schedules(
        current_times, min_signal_rate, max_signal_rate
    )

    stateful_unet = StatefulLuxLayer{true}(unet, ps, st)
    pred_noises = denoise!(
        pred_images, stateful_unet, noisy_images, noise_rates, signal_rates
    )

    # Mix the predicted image and noise again with the rates of the next time step.
    next_times = current_times .- step_size
    next_noise_rates, next_signal_rates = diffusion_schedules(
        next_times, min_signal_rate, max_signal_rate
    )
    noisy_images .= next_signal_rates .* pred_images .+ next_noise_rates .* pred_noises
    return nothing
end
"""
    reverse_diffusion(model, initial_noise, ps, st, diffusion_steps)

Call `reverse_diffusion_single_step!` `diffusion_steps` times, starting from
`initial_noise`, and return the image estimate written by the last step.

`initial_noise` itself is used as the noisy-image buffer (no copy is made here), and the
step function assigns into that buffer.
"""
function reverse_diffusion(
model::DDIM, initial_noise::AbstractArray{T,4}, ps, st::NamedTuple, diffusion_steps::Int
) where {T}
step_size = one(T) / diffusion_steps
# One entry per sample; the step function derives the diffusion times from it.
ones_dev = ones_like(initial_noise, (1, 1, 1, size(initial_noise, 4)))
# `pred_images` starts uninitialised and is filled by every step. With
# `diffusion_steps == 0` the loop never runs and this buffer is returned as is.
noisy_images, pred_images = initial_noise, similar(initial_noise)
# NOTE(review): `step` runs over 1:diffusion_steps while the step function uses
# `ones .- step_size * step`, so the first time is 1 - step_size and the last "next"
# time is -step_size. Confirm this offset is intended.
@trace for step in 1:diffusion_steps
reverse_diffusion_single_step!(
pred_images,
noisy_images,
step,
step_size,
model.unet,
ps.unet,
st.unet,
ones_dev,
model.min_signal_rate,
model.max_signal_rate,
)
end
return pred_images
end
"""
    denormalize(x, μ, σ², ε)

Undo a per-channel normalisation of `x` (channels along dimension 3):
`x .* sqrt(σ² + ε) .+ μ`.

`μ` and `σ²` hold one value per channel. The channel count is taken from their lengths
rather than being fixed to 3, so single-channel inputs work as well.
"""
function denormalize(x::AbstractArray{T,4}, μ, σ², ε) where {T}
    # Put the statistics on the channel axis so they broadcast over space and batch.
    μ_c = reshape(μ, 1, 1, length(μ), 1)
    σ_c = sqrt.(reshape(σ², 1, 1, length(σ²), 1) .+ T(ε))
    return σ_c .* x .+ μ_c
end
"""
    create_image_list(imgs)

Split `imgs` along dimension 4 and turn every slice into a colour image: a `Gray` view
when the slice has a single channel, otherwise an `RGB` view with the channel dimension
moved to the front. Each image is returned as its adjoint (rows and columns swapped).
"""
function create_image_list(imgs::AbstractArray)
    to_image = function (img)
        pixels = if size(img, 3) == 1
            colorview(Gray, view(img, :, :, 1))
        else
            colorview(RGB, permutedims(img, (3, 1, 2)))
        end
        return pixels'
    end
    return map(to_image, eachslice(imgs; dims=4))
end
"""
    create_image_grid(images, grid_rows, grid_cols=nothing)

Tile `images` row by row into one `grid_rows × grid_cols` canvas.

`images` is either a vector of equally sized matrices or a higher-dimensional array,
which is first converted with `create_image_list`. When `grid_cols` is `nothing`, the
smallest column count that fits all images is used. Cells without an image are filled
with `zero(eltype)` instead of being left uninitialised.

Throws an `ArgumentError` for an empty image list; the grid-capacity check is an
`@assert`.
"""
function create_image_grid(
    images::AbstractArray, grid_rows::Int, grid_cols::Union{Int,Nothing}=nothing
)
    image_list = ndims(images) != 1 ? create_image_list(images) : images
    isempty(image_list) &&
        throw(ArgumentError("cannot build an image grid from an empty image list"))

    # Round up so a count that is not a multiple of `grid_rows` still fits.
    ncols = grid_cols === nothing ? cld(length(image_list), grid_rows) : grid_cols

    # Check if the number of images matches the grid
    @assert length(image_list) ≤ grid_rows * ncols

    # Get the size of a single image (assuming all images are the same size)
    img_height, img_width = size(image_list[1])

    # Create a blank grid canvas; `similar` alone would leave unused cells holding
    # arbitrary memory.
    grid_canvas = similar(image_list[1], img_height * grid_rows, img_width * ncols)
    fill!(grid_canvas, zero(eltype(grid_canvas)))

    # Place each image in the correct position on the canvas
    for (idx, img) in enumerate(image_list)
        row, col = divrem(idx - 1, ncols) .+ 1
        row_range = ((row - 1) * img_height + 1):(row * img_height)
        col_range = ((col - 1) * img_width + 1):(col * img_width)
        grid_canvas[row_range, col_range] .= img
    end
    return grid_canvas
end
"""
    save_images(output_dir, images)

Write every image to `output_dir` as `img_<index>.png`. Images containing `NaN` are
skipped with a warning. `output_dir` is created first if it does not exist, so the
function does not depend on the caller or the image writer to create it.
"""
function save_images(output_dir, images::Vector{<:AbstractMatrix{<:RGB}})
    isdir(output_dir) || mkpath(output_dir)
    for (i, img) in enumerate(images)
        if any(isnan, img)
            @warn "NaN image found in the generated images. Skipping..."
            continue
        end
        save(joinpath(output_dir, "img_$(i).png"), img)
    end
    return nothing
end
# Dataset
We will register the dataset using the DataDeps.jl package. The dataset is available at https://www.robots.ox.ac.uk/~vgg/data/flowers/102/. This allows us to automatically download the dataset when we run the code.
julia
# Image dataset: the target spatial size (a 2-tuple), the paths of all image files, and
# the transform applied to each image when it is loaded.
@concrete struct FlowersDataset
    image_size
    image_files
    transform
end

# Convenience constructor for square images.
function FlowersDataset(image_size::Int)
    return FlowersDataset((image_size, image_size))
end

"""
    FlowersDataset(image_size::Dims{2})

Locate the `"FlowersDataset"` data dependency (registering it first when it is not yet
known), list the files in its `jpg` directory, and build the transform: aspect-preserving
scaling, centre crop to `image_size`, an optional flip, and conversion to a tensor.
"""
function FlowersDataset(image_size::Dims{2})
    dirpath = try
        joinpath(datadep"FlowersDataset", "jpg")
    catch err
        # Only an unregistered data dependency is handled here.
        if !(err isa KeyError)
            rethrow()
        end
        flowers_dep = DataDep(
            "FlowersDataset",
            "102 Category Flowers Dataset",
            "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz",
            "2d01ecc807db462958cfe3d92f57a8c252b4abd240eb955770201e45f783b246";
            post_fetch_method=file -> run(`tar -xzf $file`),
        )
        register(flowers_dep)
        joinpath(datadep"FlowersDataset", "jpg")
    end

    image_files = readdir(dirpath; join=true)

    resize_and_crop = ScaleKeepAspect(image_size) |> CenterResizeCrop(image_size)
    transform = resize_and_crop |> Maybe(FlipX{2}()) |> ImageToTensor()

    return FlowersDataset(image_size, image_files, transform)
end
# Number of samples = number of image files.
function Base.length(ds::FlowersDataset)
    return length(ds.image_files)
end

# Load image `i`, apply the dataset transform and return the data as `Float32`.
function Base.getindex(ds::FlowersDataset, i::Int)
    raw = load(ds.image_files[i])
    transformed = apply(ds.transform, Image(raw))
    return Float32.(itemdata(transformed))
end
"""
    getindex(ds::FlowersDataset, idxs)

Load the images selected by `idxs` into one `Float32` array of size
`(ds.image_size..., 3, length(idxs))`. The samples are loaded and transformed with
`tforeach`, each one writing its own slice along dimension 4.
"""
function Base.getindex(ds::FlowersDataset, idxs)
    batch = Array{Float32,4}(undef, ds.image_size..., 3, length(idxs))
    tforeach(1:length(idxs)) do slot
        item = apply(ds.transform, Image(load(ds.image_files[idxs[slot]])))
        return copyto!(selectdim(batch, 4, slot), itemdata(item))
    end
    return batch
end
"""
    loss_function(model, ps, st, data)

Training objective: run the model on `data` and compute the mean squared error of the
predicted noise (the value that is optimised) and of the predicted images (reported
only).

Returns `(noise_loss, new_state, (; image_loss, noise_loss))`.
"""
function loss_function(model, ps, st, data)
    (noises, images, pred_noises, pred_images), stₙ = Lux.apply(model, data, ps, st)
    mse = MSELoss()
    noise_loss = mse(pred_noises, noises)
    image_loss = mse(pred_images, images)
    return noise_loss, stₙ, (; image_loss, noise_loss)
end
# Entry Point for our code
julia
"""
    main(; kwargs...)

Train the DDIM model on the flowers dataset or, with `inference_mode=true`, load a saved
checkpoint and only generate images.

# Keywords
  - `epochs`, `batchsize`, `image_size`: training length, batch size and (square) image
    size.
  - `learning_rate_start`, `learning_rate_end`: endpoints passed to the `CosAnneal`
    schedule, which is evaluated once per epoch.
  - `weight_decay`: `lambda` of `AdamW`.
  - `checkpoint_interval`, `generate_image_interval`: a checkpoint is written / images
    are generated every that many epochs and additionally after the last epoch.
  - `expt_dir`: output directory; when empty, a timestamped directory next to this file
    is used.
  - `diffusion_steps`, `generate_n_images`: arguments forwarded to `generate`.
  - `channels`, `block_depth`, `min_freq`, `max_freq`, `embedding_dims`,
    `min_signal_rate`, `max_signal_rate`: model hyper-parameters forwarded to `DDIM`.
  - `inference_mode`, `saved_model_path`: inference-only switch and the JLD2 file that
    holds `parameters` and `states`.

# Returns
The final training state after training, or `nothing` in inference mode.
"""
function main(;
epochs::Int=100,
image_size::Int=128,
batchsize::Int=128,
learning_rate_start::Float32=1.0f-3,
learning_rate_end::Float32=1.0f-5,
weight_decay::Float32=1.0f-6,
checkpoint_interval::Int=25,
expt_dir="",
diffusion_steps::Int=80,
generate_image_interval::Int=5,
# model hyper params
channels::Vector{Int}=[32, 64, 96, 128],
block_depth::Int=2,
min_freq::Float32=1.0f0,
max_freq::Float32=1000.0f0,
embedding_dims::Int=32,
min_signal_rate::Float32=0.02f0,
max_signal_rate::Float32=0.95f0,
# inference specific
inference_mode::Bool=false,
saved_model_path=nothing,
generate_n_images::Int=12,
)
# Default experiment directory: UTC timestamp plus a random 4-character suffix.
if isempty(expt_dir)
expt_dir =
joinpath(@__DIR__, string(now(UTC)) * "_" * uppercase(randstring(4))) * "_ddim"
end
isdir(expt_dir) || mkpath(expt_dir)
@printf "[%s] [Info] Experiment directory: %s\n" now(UTC) expt_dir
# Fixed seed for the default RNG.
rng = Random.default_rng()
Random.seed!(rng, 1234)
image_dir = joinpath(expt_dir, "images")
isdir(image_dir) || mkpath(image_dir)
ckpt_dir = joinpath(expt_dir, "checkpoints")
isdir(ckpt_dir) || mkpath(ckpt_dir)
# `xdev` moves data to the Reactant device, `cdev` back to the CPU.
xdev = reactant_device(; force=true)
cdev = cpu_device()
@printf "[%s] [Info] Building model\n" now(UTC)
model = DDIM(
(image_size, image_size);
channels,
block_depth,
min_freq,
max_freq,
embedding_dims,
min_signal_rate,
max_signal_rate,
)
# Inference-only path: load the checkpoint, generate, save, and return early.
if inference_mode
@assert saved_model_path !== nothing "`saved_model_path` must be specified for \
inference"
@load saved_model_path parameters states
ps, st = (parameters, states) |> xdev
generated_images = @jit generate(
model, ps, Lux.testmode(st), diffusion_steps, generate_n_images
)
generated_images = generated_images |> cdev
# NOTE(review): `path` is not created in this function; confirm that `save_images`
# (or the image writer it calls) creates the directory.
path = joinpath(image_dir, "inference")
@printf "[%s] [Info] Saving generated images to %s\n" now(UTC) path
imgs = create_image_list(generated_images)
save_images(path, imgs)
if IN_VSCODE
display(create_image_grid(generated_images, 8, cld(length(imgs), 8)))
end
return nothing
end
ps, st = Lux.setup(rng, model) |> xdev
tb_dir = joinpath(expt_dir, "tb_logs")
@printf "[%s] [Info] Tensorboard logs being saved to %s. Run tensorboard with \
`tensorboard --logdir %s`\n" now(UTC) tb_dir dirname(tb_dir)
tb_logger = TBLogger(tb_dir)
opt = AdamW(; eta=learning_rate_start, lambda=weight_decay)
scheduler = CosAnneal(learning_rate_start, learning_rate_end, epochs)
tstate = Training.TrainState(model, ps, st, opt)
@printf "[%s] [Info] Preparing dataset\n" now(UTC)
ds = FlowersDataset(image_size)
# `partial=false` is passed so that a trailing smaller batch is not produced.
data_loader =
DataLoader(ds; batchsize, shuffle=true, partial=false, parallel=false) |> xdev
pt = ProgressTable(;
header=["Epoch", "Image Loss", "Noise Loss", "Time (s)", "Throughput (img/s)"],
widths=[10, 24, 24, 24, 24],
format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"],
color=[:normal, :normal, :normal, :normal, :normal],
border=true,
alignment=[:center, :center, :center, :center, :center],
)
@printf "[%s] [Info] Compiling generate function\n" now(UTC)
time_start = time()
# Compile `generate` once up front; the compiled function is reused in the loop below.
generate_compiled = @compile generate(
model, ps, Lux.testmode(st), diffusion_steps, generate_n_images
)
@printf "[%s] [Info] Compiled generate function in %.6f seconds\n" now(UTC) (
time() - time_start
)
# Per-batch losses; every entry is overwritten in each epoch before the mean is taken.
image_losses = Vector{Float32}(undef, length(data_loader))
noise_losses = Vector{Float32}(undef, length(data_loader))
# Global batch counter used as the TensorBoard step.
step = 1
@printf "[%s] [Info] Training model\n" now(UTC)
initialize(pt)
for epoch in 1:epochs
total_time = 0.0
total_samples = 0
# The learning rate is updated once per epoch from the schedule.
eta = Float32(scheduler(epoch))
tstate = Optimisers.adjust!(tstate, eta)
log_value(tb_logger, "Training/Learning Rate", eta; step)
start_time = time()
for (i, data) in enumerate(data_loader)
(_, loss, stats, tstate) = Training.single_train_step!(
AutoEnzyme(), loss_function, data, tstate; return_gradients=Val(false)
)
@assert !isnan(loss) "NaN loss ($(loss)) encountered!"
# The batch size is read from the last dimension of `data`.
total_samples += size(data, ndims(data))
image_losses[i] = stats.image_loss
noise_losses[i] = stats.noise_loss
log_value(tb_logger, "Training/Image Loss", Float32(stats.image_loss); step)
log_value(tb_logger, "Training/Noise Loss", Float32(stats.noise_loss); step)
log_value(
tb_logger,
"Training/Throughput",
total_samples / (time() - start_time);
step,
)
step += 1
end
total_time = time() - start_time
next(
pt,
[
epoch,
mean(image_losses),
mean(noise_losses),
total_time,
total_samples / total_time,
],
)
# Periodically (and after the final epoch) sample images with the current weights.
if epoch % generate_image_interval == 0 || epoch == epochs
generated_images = generate_compiled(
tstate.model,
tstate.parameters,
Lux.testmode(tstate.states),
diffusion_steps,
generate_n_images,
)
generated_images = generated_images |> cdev
# NOTE(review): as in the inference branch, `path` is not created here.
path = joinpath(image_dir, "epoch_$(epoch)")
imgs = create_image_list(generated_images)
save_images(path, imgs)
log_images(tb_logger, "Generated Images", imgs; step)
if IN_VSCODE
display(create_image_grid(generated_images, 8, cld(length(imgs), 8)))
end
end
# Periodically (and after the final epoch) write parameters and states to a JLD2 file.
if epoch % checkpoint_interval == 0 || epoch == epochs
path = joinpath(ckpt_dir, "model_$(epoch).jld2")
@printf "[%s] [Info] Saving checkpoint to %s\n" now(UTC) path
parameters = tstate.parameters |> cdev
states = tstate.states |> cdev
@save path parameters states
end
end
finalize(pt)
# NOTE(review): nothing is saved after this message; the last checkpoint is the one
# written inside the loop when `epoch == epochs`.
@printf "[%s] [Info] Saving final model\n" now(UTC)
return tstate
end
"""
    get_argparse_settings()

Build the `ArgParseSettings` describing the command-line options of this script. With
`autofix_names=true`, an option such as `--image-size` is stored under the key
`image_size`.
"""
function get_argparse_settings()
    s = ArgParseSettings(; autofix_names=true)
    # NOTE(review): the learning-rate defaults below (3.0f-3 / 1.0f-4) differ from the
    # keyword defaults of `main` (1.0f-3 / 1.0f-5) — confirm which pair is intended.
    #! format: off
    @add_arg_table! s begin
        "--epochs"
            help = "Number of epochs to train"
            arg_type = Int
            default = 100
        "--image-size"
            help = "Input image size (square)"
            arg_type = Int
            default = 128
        "--batchsize"
            help = "Training batch size"
            arg_type = Int
            default = 128
        "--learning-rate-start"
            help = "Starting learning rate"
            arg_type = Float32
            default = 3.0f-3
        "--learning-rate-end"
            help = "Final learning rate"
            arg_type = Float32
            default = 1.0f-4
        "--weight-decay"
            help = "Weight decay (AdamW lambda)"
            arg_type = Float32
            default = 1.0f-6
        "--checkpoint-interval"
            help = "Save checkpoint every N epochs"
            arg_type = Int
            default = 25
        "--expt-dir"
            help = "Experiment output directory"
            arg_type = String
            default = ""
        "--diffusion-steps"
            help = "Number of DDIM reverse diffusion steps"
            arg_type = Int
            default = 80
        "--generate-image-interval"
            help = "Generate and log images every N epochs"
            arg_type = Int
            default = 5

        # model hyper params
        "--channels"
            help = "UNet channels per stage"
            arg_type = Int
            nargs = '+'
            default = [32, 64, 96, 128]
        "--block-depth"
            help = "Number of residual blocks per stage"
            arg_type = Int
            default = 2
        "--min-freq"
            help = "Sinusoidal embedding min frequency"
            arg_type = Float32
            default = 1.0f0
        "--max-freq"
            help = "Sinusoidal embedding max frequency"
            arg_type = Float32
            default = 1000.0f0
        "--embedding-dims"
            help = "Sinusoidal embedding dimension"
            arg_type = Int
            default = 32
        "--min-signal-rate"
            help = "Minimum signal rate"
            arg_type = Float32
            default = 0.02f0
        "--max-signal-rate"
            help = "Maximum signal rate"
            arg_type = Float32
            default = 0.95f0

        # inference specific
        "--inference"
            help = "Run in inference-only mode"
            action = :store_true
        "--saved-model-path"
            help = "Path to JLD2 checkpoint (required with --inference)"
            arg_type = String
        "--generate-n-images"
            help = "Number of images to generate during inference or periodic logging"
            arg_type = Int
            default = 12
    end
    #! format: on
    return s
end
# Run `main` only when this file is the program being executed: parse the command-line
# arguments and forward them as keyword arguments. The `--inference` flag is stored
# under `:inference` and maps to the `inference_mode` keyword.
if abspath(PROGRAM_FILE) == @__FILE__
    args = parse_args(ARGS, get_argparse_settings(); as_symbols=true)
    main(;
        epochs=args[:epochs],
        image_size=args[:image_size],
        batchsize=args[:batchsize],
        learning_rate_start=args[:learning_rate_start],
        learning_rate_end=args[:learning_rate_end],
        weight_decay=args[:weight_decay],
        checkpoint_interval=args[:checkpoint_interval],
        expt_dir=args[:expt_dir],
        diffusion_steps=args[:diffusion_steps],
        generate_image_interval=args[:generate_image_interval],
        channels=args[:channels],
        block_depth=args[:block_depth],
        min_freq=args[:min_freq],
        max_freq=args[:max_freq],
        embedding_dims=args[:embedding_dims],
        min_signal_rate=args[:min_signal_rate],
        max_signal_rate=args[:max_signal_rate],
        inference_mode=args[:inference],
        saved_model_path=args[:saved_model_path],
        generate_n_images=args[:generate_n_images],
    )
end
# This page was generated using Literate.jl.