连续时间马尔可夫链在离散扩散模型中的应用与实现
1. 从“离散”到“连续”为什么我们需要连续时间马尔可夫链如果你接触过图像生成大概率听说过扩散模型。从Stable Diffusion到DALL-E这些模型通过逐步向图片添加噪声再学习如何逆向去噪从而生成逼真的图像。但你是否想过如果我们的数据不是连续的图像像素而是离散的类别呢比如文本中的一个词、蛋白质序列中的一个氨基酸、或者分子结构中的一个原子类型。这时传统的基于高斯噪声的连续扩散模型就有点“水土不服”了。这就是“离散扩散模型”要解决的问题。它的核心思想很直观对于一个离散状态比如词汇表中的某个词我们不是添加微小的连续噪声而是以一定的概率将其“跳变”到其他状态比如变成另一个词或者一个特殊的“[MASK]”标记。这个过程可以用一个离散时间的马尔可夫链来描述在每一步状态都根据一个转移矩阵发生随机变化。然而离散时间模型有个天然的“缺陷”时间步是离散的、固定的。我们得预先设定好要扩散多少步比如1000步每一步的噪声强度转移概率也需要精心设计一个调度表。这带来了几个麻烦首先采样速度受限于步数想快也快不了其次设计这个调度表本身就是一个需要调参的玄学问题最后离散的步骤让理论分析变得不那么优雅。于是一个自然的想法出现了能不能让这个跳变过程在“连续的时间”里发生就像观察一杯清水滴入墨水墨水的扩散是一个在时间上连续的过程而不是一秒跳一下。这就是“基于连续时间马尔可夫链的离散扩散模型”的出发点。它将离散状态在连续时间轴上的随机演化过程形式化用一组微分方程来描述状态概率分布的变化。这样做的好处是革命性的采样过程可以借助高效的数值ODE常微分方程求解器实现任意步数的快速采样噪声调度在连续时间下称为“速率函数”的设计有了更坚实的理论基础并且整个框架在数学上更加统一和优美。最近网络上的技术热词无论是yolov8训练自己的数据集还是resnet预训练模型都体现了大家对“如何高效训练和利用模型”的持续关注。而像adc采样、同步采样原理这类硬件相关的热词则从另一个维度强调了“采样”这一行为在信号处理中的核心地位。将“连续时间”的思想引入离散扩散正是为了给离散数据的“生成采样”找到一条像ODE求解那样高效、可控的新路径。它不是为了取代已有的离散扩散而是为其提供了一个更强大、更灵活的理论和计算框架。2. 核心原理拆解连续时间马尔可夫链如何驱动离散扩散要理解这个模型我们需要先抛开复杂的公式从直观上把握两个核心概念状态空间和转移速率。想象一个非常简单的例子我们建模的数据是二进制信号每个位置只能是0或1。这就是一个离散状态空间只有两个状态。在离散时间扩散里我们可能规定每一步每个比特有10%的概率翻转0变1或1变0。在连续时间框架下我们不再说“每一步”的概率而是定义转移速率。比如我们可以定义从状态0跳转到状态1的瞬时速率为 β(t)从1跳转到0的速率也是 β(t)。这里的 β(t) 是一个关于连续时间 t 的函数它控制了噪声注入的强度随时间如何变化。2.1 前向扩散过程连续时间的“加噪”在连续时间马尔可夫链CTMC的设定下前向扩散过程被描述为一个随机过程从真实数据分布在t0时刻开始随着时间t从0向更大的T增长数据状态根据定义好的转移速率矩阵随机跳变。这个过程的数学核心是科尔莫戈罗夫向前方程又称主方程。它本质上是一个微分方程描述了在任意时刻t数据处于各个离散状态的概率分布是如何随时间演化的。具体来说如果我们用向量 p(t) 来表示在时刻t处于各个状态的概率分布那么主方程可以写作 dp(t)/dt Q(t)^T p(t) 这里的 Q(t) 就是速率矩阵。Q(t) 的非对角线元素 Q_{ij}(t) (i≠j) 就表示从状态i跳转到状态j的瞬时速率。对角线元素 Q_{ii}(t) 则为负的跳出速率之和以保证每行之和为0。这个方程告诉我们概率分布的变化率等于当前分布左乘速率矩阵的转置。注意这里有一个关键但容易混淆的点。在连续时间扩散模型中我们通常设定一个“先验分布”比如一个均匀分布或一个吸收态例如全[MASK]。前向过程的目标是当时间t足够大t→T时无论初始数据是什么其分布都会演化到这个简单的先验分布。速率函数 β(t) 的设计就是为了保证这一点。2.2 逆向生成过程学习去噪的“漂移”生成采样是我们的终极目标。既然前向过程把数据变成了噪声先验分布那么如果我们能逆转这个过程就能从噪声中生成数据。幸运的是对于CTMC描述的扩散过程理论上存在一个对应的逆向时间过程它也是一个连续时间马尔可夫链。这个逆向过程的速率矩阵依赖于前向过程的速率矩阵以及一个关键的量在给定未来时刻状态的情况下当前时刻状态的条件概率。这个条件概率正是我们需要神经网络去学习的目标通常我们定义一个模型比如一个Transformer输入是t时刻的带噪数据 x_t输出是对所有可能状态的一个评分logits这个评分经过softmax后就模拟了逆向过程所需的条件概率分布。因此逆向生成过程可以这样进行我们从先验分布tT中随机采样一个初始“噪声”状态然后沿着时间t从T回溯到0求解一个关于逆向过程概率流的微分方程。在这个过程中神经网络预测的条件概率被用来计算逆向的“漂移”项引导概率分布逐渐从先验变回真实数据分布。2.3 与离散时间模型的对比为了更清晰地看到连续时间框架的优势我们可以将其与经典离散时间扩散模型做一个对比特性维度离散时间扩散模型基于CTMC的连续时间扩散模型时间域离散的步数 {0, 1, 2, ..., N}连续的区间 [0, T]噪声过程每一步应用一个转移矩阵由连续的速率矩阵 Q(t) 定义核心方程递推关系p_{k1} p_k * P_k微分方程主方程dp/dt Q(t)^T p采样灵活性必须按固定步数顺序执行可使用ODE求解器支持自适应步长、快速采样调度设计需要为每个离散步设计转移概率只需设计连续的速率函数 β(t)更灵活且易于分析理论统一性相对独立与连续数据扩散模型SDE/ODE共享更统一的数学框架这种连续化的表述使得我们可以借鉴在连续扩散模型中已经非常成熟的加速采样技术比如DDIM、DPM-Solver等思想将其适配到离散状态空间从而实现数量级上的采样提速。3. 模型训练如何教会网络预测“逆向条件概率”训练是整个模型的核心环节目标是得到一个能够准确预测逆向过程所需条件概率的神经网络。这里最常用的方法是基于变分推断的损失函数也称为去噪分数匹配在离散空间上的一个变体。3.1 训练目标函数的推导损失函数的设计直观上是为了让模型预测的条件概率分布与真实的前向过程“后验分布”尽可能接近。所谓后验分布是指如果我们已知在稍晚的某个时刻ss t的数据状态x_s那么它在较早时刻t的真实状态x_t的概率分布是怎样的通过一番数学推导利用贝叶斯定理和CTMC的性质我们可以得到一个相对简洁的损失函数形式。对于单个数据样本x_0干净数据在随机采样一个时间点t和该时间点对应的带噪状态x_t后损失函数通常可以表示为一种加权的交叉熵损失L E_{t, x_t} [ w(t) * CE( model(x_t, t), x_0 ) ]这里t是从时间区间[0, T]中均匀或按某种重要性分布采样得到的。x_t是通过模拟前向过程从x_0在时间t演化得到的一个随机样本。model(x_t, t)是神经网络输出的logits经过softmax后得到对各个状态预测的概率分布。CE是交叉熵损失衡量模型预测分布与“目标”分布之间的差异。w(t)是一个与时间相关的权重函数通常用于平衡不同时间点损失的重要性例如更关注中间时间点。这里有一个极其关键的细节目标分布是什么一个最直接的想法是让模型直接预测原始的干净数据x_0。这在很多情况下是有效的被称为“x_0预测参数化”。然而对于某些离散扩散过程特别是那些有吸收态如[MASK]的直接预测x_0在训练初期可能非常困难。因此另一种更稳定、更常用的参数化方式是预测**“去噪后的数据分布”**或者说是预测在给定x_t和t的情况下x_0的后验期望。在代码实现中这通常体现为让模型输出一个与x_0同维度的logits其训练目标就是让这个logits经过softmax后与x_0的one-hot向量的交叉熵最小。3.2 训练中的实用技巧与坑点在实际训练中有几个点需要特别注意这些往往是论文不会细说但实践中却能决定成败的“暗坑”。1. 时间步的采样策略时间t不能真的从[0, T]均匀采样。因为在t接近0时数据几乎没被污染去噪任务太简单在t接近T时数据已完全变成先验噪声去噪任务几乎不可能且对最终生成质量贡献小。因此需要采用重要性采样。一种常见的策略是从一个偏向中间时间点的分布中采样t例如采用对数正态分布或者简单地在时间域进行平方或立方采样即采样 u ~ Uniform[0,1]然后令 t T * u^2。这能确保模型将更多的学习容量分配给具有挑战性且重要的中等噪声水平阶段。2. 损失权重的选择权重函数w(t)的选择对生成质量有微妙影响。w(t) 1是一种朴素选择。但研究表明类似于连续扩散模型中的“信噪比”加权在离散扩散中设置一个与“前向过程信噪比”成反比的权重往往能取得更均衡的结果。这需要根据你选定的速率函数β(t)进行推导。一个实用的起点是尝试w(t) 1 / (预期噪声比例)然后根据验证集上的生成质量进行微调。3. 速率函数β(t)的设计这是连续时间离散扩散模型的“超参数”相当于离散模型的噪声调度表。常见的选择有线性调度β(t) β_min (β_max - β_min) * (t / T)。简单但可能不是最优。余弦调度借鉴连续扩散令信噪比按余弦函数衰减通常能获得更平滑的过渡和更好的效果。学习得到的调度将β(t)参数化为一个小型神经网络与主模型一起学习。这潜力最大但增加了训练复杂度和不稳定性。对于大多数初次尝试从余弦调度开始是一个稳健的选择。4. 掩码策略与吸收态对于文本等序列数据前向过程常常设计为以一定速率将词元替换为一个特殊的[MASK]标记吸收态。在连续时间框架下这意味着向[MASK]状态的转移速率不为零而从[MASK]跳出的速率为零一旦被掩码就停留在那里。这种设计简化了先验分布最终全部是[MASK]但也带来了挑战模型在生成后期需要“无中生有”地预测出被掩码的词。训练时需要确保损失函数能正确处理这种非对称的转移。4. 采样算法详解从理论ODE到实际代码生成训练好的模型只是一个概率分布预测器。如何利用它从先验噪声中一步步“雕刻”出最终的数据样本就是采样算法的任务。连续时间框架的魅力在此展露无遗。4.1 概率流ODE与求解器我们已经知道逆向过程的演化也服从一个微分方程即概率流常微分方程PF-ODE。这个方程的形式是 dx_t / dt f(x_t, t) 这里的f是一个“漂移”项它由前向速率矩阵 Q(t) 和神经网络预测的条件概率或得分共同决定。具体表达式依赖于你所采用的具体参数化方式预测x_0还是预测得分。一旦有了这个ODE采样就变成了一个数值求解问题我们从 t T 时刻从先验分布例如所有词都是[MASK]中采样一个初始状态 x_T然后使用一个ODE求解器沿着时间从T积分到0最终得到 x_0即生成的样本。为什么这比离散采样快在离散时间模型中你必须严格地执行N步比如1000步每一步都要调用一次模型。在连续时间ODE求解中你可以使用高阶自适应步长求解器如Runge-Kutta方法或DPM-Solver。这些求解器可以根据曲线局部的“平滑度”动态调整步长。在变化平缓的区域例如生成后期细节微调它可以迈出很大的步长在变化剧烈的区域例如生成中期主体结构形成它会自动缩小步长以保证精度。这意味着可能只需要20-50次模型评估NFE就能达到原来1000步的效果实现了10-50倍的加速。4.2 几种实用的采样方案1. 欧拉法最简单的ODE求解器这相当于将连续时间离散化是最直接的实现方式。步骤是设置总时间T计划步数N例如20步。计算时间步长 Δt T / N。从先验分布采样 x_N。fori from N to 1:t i * Δt根据当前状态 x_t 和时间 t用模型计算漂移项 f(x_t, t)更新状态x_{t-Δt} x_t - f(x_t, t) * Δt 注意符号逆向时间是倒退的得到 x_0。这种方法简单但精度较低可能需要较多的步数如100步才能保证质量。2. Heun法二阶ODE求解器这是欧拉法的改进版通过多计算一次模型来获得更精确的梯度估计属于预测-校正类方法。步骤大致为预测步计算f_t f(x_t, t)得到预测状态x_t_p x_t - Δt * f_t。校正步在预测状态处再计算梯度f_{t-Δt} f(x_t_p, t-Δt)。使用平均梯度更新x_{t-Δt} x_t - Δt * (f_t f_{t-Δt}) / 2。 Heun法在每一步需要两次模型评估但精度更高通常可以用更少的步数达到相同效果。3. 基于得分的采样器如DPM-Solver适配对于预测“得分”即对数概率的梯度的模型参数化方式可以专门适配DPM-Solver这类为扩散模型设计的高阶求解器。DPM-Solver利用了扩散过程ODE解的特殊结构通过指数积分器来实现更高阶的精度。它的实现比Heun法复杂但通常能在10-20步内达到极佳的采样质量是目前SOTA方法的首选。实操心得在项目初期强烈建议从简单的欧拉法开始实现以确保整个采样流程正确。在验证了流程和模型的基本有效性后再尝试集成更高效的Heun法或DPM-Solver。你可以将不同的求解器封装成可插拔的模块方便后续对比和调优。4.3 处理离散状态的挑战Straight-Through技巧在ODE求解中状态 x_t 理论上是一个连续的概率分布向量各个状态的概率。但在实际迭代中我们通常需要将其“物化”为一个具体的离散状态才能输入到神经网络中因为网络通常接受离散的token ID或one-hot向量。这里的一个常见技巧是Straight-Through Estimator (STE)。具体做法是在每一步ODE求解后我们得到的是一个连续的概率分布向量 p_t。为了得到下一个时刻的输入状态我们从 p_t 中采样一个具体的离散状态 x_t例如根据概率进行多项式采样。然而采样操作是不可导的会阻断梯度回传。STE的做法是在反向传播时假装采样操作就是直接选择了概率最大的那个状态argmax或者说直接使用 p_t 的softmax logits作为离散状态的“连续近似”来通过梯度。在下一个前向传播中我们依然使用采样得到的离散状态 x_t。这种方法在实践中被证明是有效的它允许梯度通过连续的概率分布进行流动同时保持了采样过程的随机性。在代码中这通常通过torch.where或detach()等操作来实现。5. 实战构建一个简单的文本字符生成模型理论说了这么多我们动手实现一个最小化的例子来直观感受整个过程。我们将构建一个模型学习生成简单的、固定长度的字符串比如5个字符每个字符取自字母表a-z。5.1 环境与数据准备我们使用PyTorch框架。首先定义一些常量import torch import torch.nn as nn import torch.nn.functional as F import numpy as np # 超参数 vocab_size 26 # a-z seq_len 5 hidden_dim 128 num_layers 3 time_embed_dim 32 T 1.0 # 总扩散时间 # 速率函数简单的线性调度 def beta(t): return 0.1 1.9 * t # beta从0.1线性增加到2.0 # 前向过程转移概率计算给定时间t和初始状态x0求xt的分布 def compute_qt(x0, t): # x0: [batch_size, seq_len]值为0-25的整数 # t: [batch_size, 1] 或标量 # 返回xt的logits [batch_size, seq_len, vocab_size] rate beta(t) if torch.is_tensor(t) else beta(t) # 构造速率矩阵Q从任意状态i到其他状态j的速率都是 rate/(vocab_size-1) 跳出速率总和为rate # 这里简化处理使用一个均匀转移矩阵 # 计算转移概率矩阵 Pt expm(Q * t) 对于均匀转移有解析解 prob_remain torch.exp(-rate * t) # 停留在原状态的概率 prob_transfer (1 - prob_remain) / (vocab_size - 1) # 转移到其他任一状态的概率 batch_size x0.size(0) # 初始化logits为转移概率 logits torch.full((batch_size, seq_len, vocab_size), fill_valueprob_transfer) # 为每个位置、每个样本将对应x0状态的概率设为prob_remain # 这里需要一些张量操作技巧 x0_one_hot F.one_hot(x0, num_classesvocab_size).float() # [B, L, V] logits logits (prob_remain - prob_transfer) * x0_one_hot return torch.log(logits 1e-8) # 返回log概率5.2 神经网络模型设计我们的模型需要接受带噪的离散序列x_t和时间嵌入t输出每个位置下一个状态的logits。我们使用一个简单的Transformer编码器。class TimeEmbedding(nn.Module): 将连续时间t映射为向量 def __init__(self, dim): super().__init__() self.dim dim half_dim dim // 2 emb np.log(10000) / (half_dim - 1) emb torch.exp(torch.arange(half_dim, dtypetorch.float) * -emb) self.register_buffer(emb, emb) def forward(self, t): # t: [batch_size, 1] emb t * self.emb emb torch.cat([torch.sin(emb), torch.cos(emb)], dim-1) if self.dim % 2 1: # 如果dim是奇数补零 emb F.pad(emb, (0, 1)) return emb # [batch_size, dim] class DiscreteDiffusionModel(nn.Module): def __init__(self, vocab_size, seq_len, hidden_dim, time_embed_dim): super().__init__() self.vocab_size vocab_size self.seq_len seq_len self.token_embed nn.Embedding(vocab_size, hidden_dim) self.time_embed TimeEmbedding(time_embed_dim) self.time_proj nn.Linear(time_embed_dim, hidden_dim) # 简单的Transformer编码器 encoder_layer nn.TransformerEncoderLayer( d_modelhidden_dim, nhead8, dim_feedforwardhidden_dim*4, batch_firstTrue, dropout0.1 ) self.transformer nn.TransformerEncoder(encoder_layer, num_layers3) self.output_layer nn.Linear(hidden_dim, vocab_size) def forward(self, x, t): # x: [batch_size, seq_len] 离散token索引 # t: [batch_size, 1] 时间 token_emb self.token_embed(x) # [B, L, H] time_emb self.time_embed(t) # [B, D_t] time_emb self.time_proj(time_emb).unsqueeze(1) # [B, 1, H] # 将时间嵌入加到每个token上 x_emb token_emb time_emb # 通过Transformer # 注意对于简单的字符级任务我们不需要因果掩码使用全注意力即可 transformer_out self.transformer(x_emb) # 预测每个位置的logits logits self.output_layer(transformer_out) # [B, L, V] return logits5.3 训练循环核心代码训练循环包括采样时间、模拟前向过程、计算损失。def train_step(model, optimizer, data_batch): data_batch: [batch_size, seq_len] 每个元素是0-25的整数 model.train() batch_size data_batch.size(0) # 1. 采样时间t t torch.rand((batch_size, 1), devicedata_batch.device) * T # 均匀采样可改为重要性采样 # 2. 模拟前向过程得到带噪样本x_t # 计算给定x0和t时xt的log概率分布 log_prob_xt_given_x0 compute_qt(data_batch, t) # [B, L, V] # 从该分布中采样具体的xt使用Gumbel-Softmax或直接多项式采样 # 使用Gumbel-Softmax以获得可微的采样训练时 xt F.gumbel_softmax(log_prob_xt_given_x0, tau1.0, hardTrue) # [B, L, V] one-hot # 将one-hot转换为token索引用于嵌入查找Straight-Through xt_tokens torch.argmax(xt, dim-1).detach() # 前向使用离散token xt_one_hot F.one_hot(xt_tokens, num_classesvocab_size).float() # 用于后续计算 # 3. 模型前向传播 pred_logits model(xt_tokens, t) # 模型接收离散token索引 # 4. 计算损失预测x0的交叉熵 # 目标让模型预测的分布接近真实的x0 loss F.cross_entropy( pred_logits.reshape(-1, vocab_size), data_batch.reshape(-1) ) # 5. 反向传播与优化 optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # 梯度裁剪 optimizer.step() return loss.item()5.4 采样欧拉法实现torch.no_grad() def euler_sampling(model, num_steps20): 使用欧拉法进行采样 model.eval() batch_size 4 # 生成4个样本 device next(model.parameters()).device # 1. 初始化从先验分布采样。这里先验是均匀分布。 # 更常见的文本扩散先验是全部为[MASK] token这里简化使用均匀分布。 x torch.randint(0, vocab_size, (batch_size, seq_len), devicedevice) # [B, L] dt T / num_steps ts torch.linspace(T, 0, num_steps1, devicedevice) # 从T到0 for i in range(num_steps): t_cur ts[i].unsqueeze(0).unsqueeze(0) # [1, 1] t_cur t_cur.expand(batch_size, -1) # [B, 1] # 2. 模型预测当前状态的logits pred_logits model(x, t_cur) # [B, L, V] pred_probs F.softmax(pred_logits, dim-1) # 3. 计算“漂移”项 f(x,t)。 # 简化版本假设逆向过程倾向于向模型预测的x0移动。 # 一种近似f ≈ (pred_probs - one_hot(x)) * beta(t) 这里beta(t)是速率。 rate beta(t_cur[0,0].item()) # 将当前状态x转为one-hot x_one_hot F.one_hot(x, num_classesvocab_size).float() # [B, L, V] # 计算漂移这里是一个启发式公式实际ODE推导更复杂 drift rate * (pred_probs - x_one_hot) # [B, L, V] # 4. 欧拉更新x_{t-dt} x_t - drift * dt 注意符号逆向时间 # 我们需要将drift作用在概率分布上然后采样新状态。 # 更新概率分布p_new x_one_hot - drift * dt p_new x_one_hot - drift * dt # 确保概率合法 p_new torch.clamp(p_new, min0) p_new p_new / p_new.sum(dim-1, keepdimTrue) # 5. 从新分布中采样下一个状态使用Straight-Through # 训练时用Gumbel-Softmax推理时直接多项式采样 x torch.multinomial(p_new.view(-1, vocab_size), 1).view(batch_size, seq_len) # 将token索引转换为字符 idx_to_char {i: chr(ord(a)i) for i in range(26)} generated_strings [] for seq in x.cpu().numpy(): chars [idx_to_char[idx] for idx in seq] generated_strings.append(.join(chars)) return generated_strings这个简化实例涵盖了从数据准备、模型定义、训练到采样的核心流程。在实际应用中你需要根据具体任务如自然语言文本、代码、生物序列设计更合适的网络结构如因果Transformer、BERT等、更精确的速率函数和更高效的采样器。6. 进阶话题与未来方向掌握了基本原理和实现后我们可以看看这个领域正在探索的一些前沿方向和待解决的挑战。1. 条件生成与控制如何让模型生成符合特定条件的内容例如给定一个情感标签生成相应情绪的文本或者根据分子属性生成特定结构的分子。主流方法是在训练时引入条件信息如标签、描述文本的嵌入在采样时通过分类器指导或无分类器指导来引导生成过程。在连续时间框架下这通常意味着在ODE的漂移项中加入一个条件梯度的加权项以增大生成样本符合目标条件的概率。2. 快速采样算法的极限虽然ODE求解器已经大大加速了采样但对于大规模模型如数十亿参数的文本扩散模型每一步的模型评估开销依然巨大。研究更高效的、步数更少的求解器如将步数压缩到10步以内是一个热点。此外一致性模型的思想也被引入离散扩散旨在学习一个能将任意噪声点直接映射到数据点的网络实现一步或极少步生成。3. 与其他生成模型的融合离散扩散模型与自回归模型、流模型等如何结合一个思路是分层扩散在粗粒度上进行扩散生成大纲然后在细粒度上自回归或扩散生成细节。另一个思路是混合训练让模型同时学习扩散和自回归目标以兼顾生成速度和质量。4. 复杂结构数据的应用当前研究已不再局限于一维序列。图结构数据分子图、社交网络、二维网格数据图像离散编码、三维结构蛋白质构象的离散扩散模型正在兴起。这些场景需要设计符合其对称性平移、旋转、置换不变性的转移速率矩阵和网络结构挑战更大但应用前景也更广阔。5. 速率函数与噪声调度的自动化学习如前所述速率函数β(t)是一个关键的超参数。让模型自己学习最优的噪声调度是另一个减少人工干预、提升性能的方向。这可以通过将β(t)参数化并与其他参数一起优化或者通过元学习的方式来实现。从我个人的实验经验来看连续时间离散扩散模型最大的优势在于其灵活性和效率。一旦你搭建好了这个框架更换不同的速率函数、尝试不同的ODE求解器、或者引入条件控制都变得模块化且相对容易。它就像为离散数据生成提供了一个强大的“数学操作系统”在此之上可以构建各种各样的应用。当然初期的调试可能会有些棘手尤其是损失函数不稳定或采样质量不佳时需要耐心地检查梯度、调整学习率、以及可视化中间生成过程。但一旦跑通你会发现它是一条非常优雅且强大的生成建模路径。