⚠️⚠️⚠️本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
终于来到扩散模型DDPM系列的最后一篇:源码解读了。本文将配合详细的图例,来为大家解读DDPM的模型架构与训练方式的代码实现。
DDPM原作的github地址在此,采用tensorflow进行实现。本文讲解代码选择的github在此,采用pytorch进行实现。
之所以没有选择原作的github进行讲解,主要基于以下原因:
-
pytorch的受众面更广。在保证模型效果复现的基础上,使用tf或pytorch进行讲解差别不大。tf技术栈的朋友们,也可以利用本文提供的图例,来阅读tf代码。
-
本文所选的github,来自于开源组织labml_nn,该组织致力于使用pytorch复现经典论文的模型,并对代码做详细的注释,对初次接触新知识的读者来说非常友好。其含代码注释的地址在此 。在此把这个宝藏学习资源分享给大家。
在阅读本文前,强烈建议大家先阅读“模型架构篇”和“人人都能看懂的数学推理篇” 。
CV大模型系列文章导航(持续更新中):
?CV大模型系列之:扩散模型基石DDPM(模型架构篇)?
?CV大模型系列之:扩散模型基石DDPM(人人都能看懂的数学原理篇)?
?CV大模型系列之:扩散模型基石DDPM(源码解读与实操篇)?
?CV大模型系列之:全面解读VIT,它到底给植树人挖了多少坑?
一、DenoiseDiffusion
1.1 回顾扩散模型整体运作流程
在模型架构篇中,我们详细阐述过扩散模型的整体运作流程,现在我们将它再次梳理一遍,方便和我们的源码对齐。
如上图,扩散模型分为两步:
-
Diffusion Process:加噪过程。取一张干净的图片,逐步往上添加高斯噪声,执行T步后生成纯高斯噪声。这个过程中遵从的分布,记为)
-
Denoise Process:去噪过程。这一步中我们训练一个UNet架构的去噪模型,它吃和,然后去预测噪声,使得逼近Diffusion Process对应步骤中采样的真值噪声。这个过程中遵从的分布,我们记为。其中,表示UNet参数。
在数学原理篇中,我们阐述过,实际操作中:
-
Diffusion Process:设置采样规则,从去产生,而不是从去产生,这个规则具体为:。其中,可理解为人为设置的一连串超参数,随着t的增大而逐渐减小。再往上追溯一层,来自于人为设定一连串超参,这串超参随着t的增大而增大。
1.1.1 Training
由于不管对任何输入数据,不管对它的任何一个time_step,模型做的都是去预测一个来自高斯分布的噪声。因此整个训练过程可设计为:
- 从训练数据中,抽样出一张图片(即)
- 随机抽样出一个timestep。(即)
- 随机抽样出一个噪声(即)
- 计算:,其中表示UNet架构的去噪模型
- 计算梯度,更新模型,重复上面过程,直至收敛
(上面演示的是单条数据计算loss的过程,当然,整个过程也可以在batch范围内做,batch中单条数据计算loss的方法不变)
1.1.2 Sampling
当模型训练好后,就可以根据上图中的公式,从任意中推出了,其中:
,这一项是我们根据严谨的数学流程推导出来的。但在DDPM论文中,已通过实验证明可用,所以为了计算方便,源码中也选用了后者。
⚠️⚠️⚠️再次建议大家,在阅读源码篇前,先阅读“模型架构篇”和“人人都能看懂的数学推理篇”
1.2 DenoiseModel
DenoiseModel定义了上述的training步骤,我们直接来看代码(一切尽在注释中):
class DenoiseDiffusion:
"""
Denoise Diffusion
"""
def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
"""
Params:
eps_model: UNet去噪模型,我们将在下文详细解读它的架构。
n_steps:训练总步数T
device:训练所用硬件
"""
super().__init__()
# 定义UNet架构模型
self.eps_model = eps_model
# 人为设置超参数beta,满足beta随着t的增大而增大,同时将beta搬运到训练硬件上
self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
# 根据beta计算alpha(参见数学原理篇)
self.alpha = 1. - self.beta
# 根据alpha计算alpha_bar(参见数学原理篇)
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
# 定义训练总步长
self.n_steps = n_steps
# sampling中的sigma_t
self.sigma2 = self.beta
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Diffusion Process的中间步骤,根据x0和t,推导出xt所服从的高斯分布的mean和var
Params:
x0:来自训练数据的干净的图片
t:某一步time_step
Return:
mean: xt所服从的高斯分布的均值
var:xt所服从的高斯分布的方差
"""
# ----------------------------------------------------------------
# gather:人为定义的函数,从一连串超参中取出当前t对应的超参alpha_bar
# 由于xt = sqrt(alpha_bar_t) * x0 + sqrt(1-alpha_bar_t) * epsilon
# 其中epsilon~N(0, I)
# 因此根据高斯分布性质,xt~N(sqrt(alpha_bar_t) * x0, 1-alpha_bar_t)
# 即为本步中我们要求的mean和var
# ----------------------------------------------------------------
mean = gather(self.alpha_bar, t) ** 0.5 * x0
var = 1 - gather(self.alpha_bar, t)
return mean, var
def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
"""
Diffusion Process,根据xt所服从的高斯分布的mean和var,求出xt
Params:
x0:来自训练数据的干净的图片
t:某一步time_step
Return:
xt: 第t时刻加完噪声的图片
"""
# ----------------------------------------------------------------
# xt = sqrt(alpha_bar_t) * x0 + sqrt(1-alpha_bar_t) * epsilon
# = mean + sqrt(var) * epsilon
# 其中,epsilon~N(0, I)
# ----------------------------------------------------------------
if eps is None:
eps = torch.randn_like(x0)
mean, var = self.q_xt_x0(x0, t)
return mean + (var ** 0.5) * eps
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
"""
Sampling, 当模型训练好之后,根据x_t和t,推出x_{t-1}
Params:
x_t:t时刻的图片
t:某一步time_step
Return:
x_{t-1}: 第t-1时刻的图片
"""
# eps_model: 训练好的UNet去噪模型
# eps_theta: 用训练好的UNet去噪模型,预测第t步的噪声
eps_theta = self.eps_model(xt, t)
# 根据Sampling提供的公式,推导出x_{t-1}
alpha_bar = gather(self.alpha_bar, t)
alpha = gather(self.alpha, t)
eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
var = gather(self.sigma2, t)
eps = torch.randn(xt.shape, device=xt.device)
return mean + (var ** .5) * eps
def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
"""
1. 随机抽取一个time_step t
2. 执行diffusion process(q_sample),随机生成噪声epsilon~N(0, I),
然后根据x0, t和epsilon计算xt
3. 使用UNet去噪模型(p_sample),根据xt和t得到预测噪声epsilon_theta
4. 计算mse_loss(epsilon, epsilon_theta)
【MSE只是众多可选loss设计中的一种,大家也可以自行设计loss函数】
Params:
x0:来自训练数据的干净的图片
noise: diffusion process中随机抽样的噪声epsilon~N(0, I)
Return:
loss: 真实噪声和预测噪声之间的loss
"""
batch_size = x0.shape[0]
# 随机抽样t
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
# 如果为传入噪声,则从N(0, I)中抽样噪声
if noise is None:
noise = torch.randn_like(x0)
# 执行Diffusion process,计算xt
xt = self.q_sample(x0, t, eps=noise)
# 执行Denoise Process,得到预测的噪声epsilon_theta
eps_theta = self.eps_model(xt, t)
# 返回真实噪声和预测噪声之间的mse loss
return F.mse_loss(noise, eps_theta)
定义好DenoiseModel
后,我们就可以进一步定义train
函数来训练模型了,这里我们只截取代码中的核心部分,总体来说,每个epoch的训练分成两个部分:
train()
: 在这一部分中,我们创建模型(DenoiseModel
),遍历所有的batch,计算loss并做梯度更新。sample()
:每个epoch训练完毕后,我们根据上图sample部分中的公式,利用当前的模型,将一张高斯噪声()逐步还原回,将用于评估当前模型的效果(例如计算FID之类)
def train(self):
"""
单epoch训练DDPM
"""
# 遍历每一个batch(monit是自定义类,详情参见github完整代码)
for data in monit.iterate('Train', self.data_loader):
# step数+1(tracker是自定义类,详情参见github完整代码)
tracker.add_global_step()
# 将这个batch的数据移动到GPU上
data = data.to(self.device)
# 每个batch开始时,梯度清0
self.optimizer.zero_grad()
# self.diffusion即为DenoiseModel实例,执行forward,计算loss
loss = self.diffusion.loss(data)
# 计算梯度
loss.backward()
# 更新
self.optimizer.step()
# 保存loss,用于后续可视化之类的操作
tracker.save('loss', loss)
def sample(self):
"""
利用当前模型,将一张随机高斯噪声(xt)逐步还原回x0,
x0将用于评估模型效果(例如FID分数)
"""
with torch.no_grad():
# 随机抽取n_samples张纯高斯噪声
x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
device=self.device)
# 对每一张噪声,按照sample公式,还原回x0
for t_ in monit.iterate('Sample', self.n_steps):
t = self.n_steps - t_ - 1
x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
# 保存x0
tracker.save('sample', x)
def run(self):
"""
train主函数
"""
# 遍历每一个epoch
for _ in monit.loop(self.epochs):
# 训练模型
self.train()
# 利用当前训好的模型做sample,从xt还原x0,保存x0用于后续效果评估
self.sample()
# 再console上新起一行
tracker.new_line()
# 保存模型(experiment是自定义类,详情参见github代码)
experiment.save_checkpoint()
二、DDPM UNet
接下来,我们就来看UNet去噪模型具体长什么样子。
2.1 UNet主体架构
我们先来关注UNet主体架构,然后在下文继续看里面每一个模块的具体代码。
在模型架构篇中,我们曾说明过:
- DDPM UNet的输入是某一时刻的图片和用于表示该时刻的t向量(t向量的具体表示形式在下文会详细说明)
- DDPM UNet的输出是对t时刻噪声的预测。
- DDPM UNet是一个典型的Encoder-Decoder结构,在Encoder中,我们压缩图片大小,逐步提取图片特征;在Decoder中,我们逐步还原图片大小。由于压缩图片可能会损失掉信息,因此在decoder做还原时,我们会拼接Encoder层对应的特征图(skip connection),尽量减少信息损失。
假设我们有一张输入为32*32*3
大小的图片,则DDPM UNet的整体运作流程如下:
我们来看下相应的代码(一切尽在注释中),同时,建议大家在阅读源码的同时,整一些加数据,亲自跑一遍主体模型,打印出output_shape,更方便大家理解源码:
class UNet(Module):
"""
DDPM UNet去噪模型主体架构
"""
def __init__(self, image_channels: int = 3, n_channels: int = 64,
ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
n_blocks: int = 2):
"""
Params:
image_channels:原始输入图片的channel数,对RGB图像来说就是3
n_channels: 在进UNet之前,会对原始图片做一次初步卷积,该初步卷积对应的
out_channel数,也就是图中左上角的第一个墨绿色箭头
ch_mults: 在Encoder下采样的每一层的out_channels倍数,
例如ch_mults[i] = 2,表示第i层特征图的out_channel数,
是第i-1层的2倍。Decoder上采样时也是同理,用的是反转后的ch_mults
is_attn: 在Encoder下采样/Decoder上采样的每一层,是否要在CNN做特征提取后再引入attention
(会在下文对该结构进行详细说明)
n_blocks: 在Encoder下采样/Decoder下采样的每一层,需要用多少个DownBlock/UpBlock(见图),
Deocder层最终使用的UpBlock数=n_blocks + 1
【到此为止没有完全看懂注释也没关系,可以一遍打开示意图,一遍继续往下阅读源码,就能满满加深理解】
"""
super().__init__()
# 在Encoder下采样/Decoder上采样的过程中,图像依次缩小/放大,
# 每次变动都会产生一个新的图像分辨率
# 这里指的就是不同图像分辨率的个数,也可以理解成是Encoder/Decoder的层数
n_resolutions = len(ch_mults)
# 对原始图片做预处理,例如图中,将32*32*3 -> 32*32*64
self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
# time_embedding,TimeEmbedding是nn.Module子类,我们会在下文详细讲解它的属性和forward方法
self.time_emb = TimeEmbedding(n_channels * 4)
# --------------------------
# 定义Encoder部分
# --------------------------
# down列表中的每个元素表示Encoder的每一层
down = []
# 初始化out_channel和in_channel
out_channels = in_channels = n_channels
# 遍历每一层
for i in range(n_resolutions):
# 根据设定好的规则,得到该层的out_channel
out_channels = in_channels * ch_mults[i]
# 根据设定好的规则,每一层有n_blocks个DownBlock
for _ in range(n_blocks):
down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
# 对Encoder来说,每一层结束后,我们都做一次下采样,但Encoder的最后一层不做下采样
if i < n_resolutions - 1:
down.append(Downsample(in_channels))
# self.down即是完整的Encoder部分
self.down = nn.ModuleList(down)
# --------------------------
# 定义Middle部分
# --------------------------
self.middle = MiddleBlock(out_channels, n_channels * 4, )
# --------------------------
# 定义Decoder部分
# --------------------------
# 和Encoder部分基本一致,可对照绘制的架构图阅读
up = []
in_channels = out_channels
for i in reversed(range(n_resolutions)):
# `n_blocks` at the same resolution
out_channels = in_channels
for _ in range(n_blocks):
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
out_channels = in_channels // ch_mults[i]
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
if i > 0:
up.append(Upsample(in_channels))
# self.up即是完整的Decoder部分
self.up = nn.ModuleList(up)
# 定义group_norm, 激活函数,和最后一层的CNN(用于将Decoder最上一层的特征图还原成原始尺寸)
self.norm = nn.GroupNorm(8, n_channels)
self.act = Swish()
self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
Params:
x: 输入数据xt,尺寸大小为(batch_size, in_channels, height, width)
t: 输入数据t,尺寸大小为(batch_size)
"""
# 取得time_embedding
t = self.time_emb(t)
# 对原始图片做初步CNN处理
x = self.image_proj(x)
# -----------------------
# Encoder
# -----------------------
h = [x]
# First half of U-Net
for m in self.down:
x = m(x, t)
h.append(x)
# -----------------------
# Middle
# -----------------------
x = self.middle(x, t)
# -----------------------
# Decoder
# -----------------------
for m in self.up:
if isinstance(m, Upsample):
x = m(x, t)
else:
s = h.pop()
# skip_connection
x = torch.cat((x, s), dim=1)
x = m(x, t)
return self.final(self.act(self.norm(x)))
到这里,我们就把DDPM UNet的主体架构讲完了,接下来我们来看架构中的子模块,主要分为以下部分:
- DownBlock(Encoder层,也就是图中每一个红色箭头)
- DownSample(Encoder层间的下采样,也就是图中每一个浅绿色箭头)
- UpBlock(Decoder层,也就是图中每个蓝色箭头)
- UpSample(Decoder曾间的上采样,也就是图中每一个紫色箭头)
- TimeEmbedding(针对整型时刻t做的向量化处理,也就是图中每一个青色箭头)
2.2 DownBlock和UpBlock
DownBlock和UpBlock的内部架构非常相似,都是Redisual + Attention,其中Attention部分不是必须的,是可选的。我们在这里只摘取DownBlock部分的代码进行讲解,UpBlock部分留给大家自己看。
图中已经绘制的很详细了,可以直接配合代码阅读。需要关注的是,虚线部分即为“残差连接”(Residual Connection) ,而残差连接之上引入的虚线框Conv的意思是,如果in_c = out_c,则对in_c做一次卷积,使得其通道数等于out_c后,再相加;否则将直接相加。
class ResidualBlock(Module):
"""
每一个Residual block都有两层CNN做特征提取
"""
def __init__(self, in_channels: int, out_channels: int, time_channels: int,
n_groups: int = 32, dropout: float = 0.1):
"""
Params:
in_channels: 输入图片的channel数量
out_channels: 经过residual block后输出特征图的channel数量
time_channels:time_embedding的向量维度,例如t原来是个整型,值为1,表示时刻1,
现在要将其变成维度为(1, time_channels)的向量
n_groups: Group Norm中的超参
dropout: dropout rate
"""
super().__init__()
# 第一层卷积 = Group Norm + CNN
self.norm1 = nn.GroupNorm(n_groups, in_channels)
self.act1 = Swish()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
# 第二层卷积 = Group Norm + CNN
self.norm2 = nn.GroupNorm(n_groups, out_channels)
self.act2 = Swish()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
# 当in_c = out_c时,残差连接直接将输入输出相加;
# 当in_c != out_c时,对输入数据做一次卷积,将其通道数变成和out_c一致,再和输出相加
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
else:
self.shortcut = nn.Identity()
# t向量的维度time_channels可能不等于out_c,所以我们要对起做一次线性转换
self.time_emb = nn.Linear(time_channels, out_channels)
self.time_act = Swish()
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
Params:
x: 输入数据xt,尺寸大小为(batch_size, in_channels, height, width)
t: 输入数据t,尺寸大小为(batch_size, time_c)
【配合图例进行阅读】
"""
# 1.输入数据先过一层卷积
h = self.conv1(self.act1(self.norm1(x)))
# 2. 对time_embedding向量,通过线性层使time_c变为out_c,再和输入数据的特征图相加
h += self.time_emb(self.time_act(t))[:, :, None, None]
# 3、过第二层卷积
h = self.conv2(self.dropout(self.act2(self.norm2(h))))
# 4、返回残差连接后的结果
return h + self.shortcut(x)
class AttentionBlock(Module):
"""
Attention模块
和Transformer中的multi-head attention原理及实现方式一致
"""
def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
"""
Params:
n_channels:等待做attention操作的特征图的channel数
n_heads: attention头数
d_k: 每一个attention头处理的向量维度
n_groups: Group Norm超参数
"""
super().__init__()
# 一般而言,d_k = n_channels // n_heads,需保证n_channels能被n_heads整除
if d_k is None:
d_k = n_channels
# 定义Group Norm
self.norm = nn.GroupNorm(n_groups, n_channels)
# Multi-head attention层: 定义输入token分别和q,k,v矩阵相乘后的结果
self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
# MLP层
self.output = nn.Linear(n_heads * d_k, n_channels)
self.scale = d_k ** -0.5
self.n_heads = n_heads
self.d_k = d_k
def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
"""
Params:
x: 输入数据xt,尺寸大小为(batch_size, in_channels, height, width)
t: 输入数据t,尺寸大小为(batch_size, time_c)
【配合图例进行阅读】
"""
# t并没有用到,但是为了和ResidualBlock定义方式一致,这里也引入了t
_ = t
# 获取shape
batch_size, n_channels, height, width = x.shape
# 将输入数据的shape改为(batch_size, height*weight, n_channels)
# 这三个维度分别等同于transformer输入中的(batch_size, seq_length, token_embedding)
# (参见图例)
x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
# 计算输入过矩阵q,k,v的结果,self.projection通过矩阵计算,一次性把这三个结果出出来
# 也就是qkv矩阵是三个结果的拼接
# 其shape为:(batch_size, height*weight, n_heads, 3 * d_k)
qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
# 将拼接结果切开,每一个结果的shape为(batch_size, height*weight, n_heads, d_k)
q, k, v = torch.chunk(qkv, 3, dim=-1)
# 以下是正常计算attention score的过程,不再做说明
attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
attn = attn.softmax(dim=2)
res = torch.einsum('bijh,bjhd->bihd', attn, v)
# 将结果reshape成(batch_size, height*weight,, n_heads * d_k)
# 复习一下:n_heads * d_k = n_channels
res = res.view(batch_size, -1, self.n_heads * self.d_k)
# MLP层,输出结果shape为(batch_size, height*weight,, n_channels)
res = self.output(res)
# 残差连接
res += x
# 将输出结果从序列形式还原成图像形式,
# shape为(batch_size, n_channels, height, width)
res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
return res
class DownBlock(Module):
"""
Down block,即Encoder中每一层的核心处理逻辑
DownBlock = ResidualBlock + AttentionBlock
在我们的例子中,Encoder的每一层都有2个DownBlock
"""
def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
super().__init__()
self.res = ResidualBlock(in_channels, out_channels, time_channels)
if has_attn:
self.attn = AttentionBlock(out_channels)
else:
self.attn = nn.Identity()
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res(x, t)
x = self.attn(x)
return x
2.3 TimeEmbedding
在2.2中,我们频繁看见time_embedding向量,那么它是怎么来的呢?
概括来说,原始的time_step是一个整数,例如1表示第一个时刻,2表示第二个时刻。
- 我们定义TimeEmbedding模块,将这个整数包装成维度=time_channel的向量,这个包装方式和Transformer中函数式位置编码的包装方式一致。
- 然后,再实际应用到time_emebdding向量时,再通过一个简单的线性层,将其维度从time_channel转变为对应特征图的out_channel,使其能够和特征图相加。
具体的过程再图中已经绘制得很清楚了,我们就直接来看代码吧(一切尽在注释中):
class TimeEmbedding(nn.Module):
"""
TimeEmbedding模块将把整型t,以Transformer函数式位置编码的方式,映射成向量,
其shape为(batch_size, time_channel)
"""
def __init__(self, n_channels: int):
"""
Params:
n_channels:即time_channel
"""
super().__init__()
self.n_channels = n_channels
self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
self.act = Swish()
self.lin2 = nn.Linear(self.n_channels, self.n_channels)
def forward(self, t: torch.Tensor):
"""
Params:
t: 维度(batch_size),整型时刻t
"""
# 以下转换方法和Transformer的位置编码一致
# 【强烈建议大家动手跑一遍,打印出每一个步骤的结果和尺寸,更方便理解】
half_dim = self.n_channels // 8
emb = math.log(10_000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=1)
# Transform with the MLP
emb = self.act(self.lin1(emb))
emb = self.lin2(emb)
# 输出维度(batch_size, time_channels)
return emb
2.4 DowSample和UpSample
这两块分别起到“压缩特征”和“还原特征”的作用,比较简单,我们直接来看代码:
class Upsample(nn.Module):
"""
上采样
"""
def __init__(self, n_channels):
super().__init__()
self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
_ = t
return self.conv(x)
class Downsample(nn.Module):
"""
下采样
"""
def __init__(self, n_channels):
super().__init__()
self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
_ = t
return self.conv(x)
2.5 MiddleBlock
MiddleBlock = ResidualBlock + AttentionBlock + ResidualBlock组成,具体结构如下图:
我们在上文讨论过ResidualBlock和AttentionBlock的具体实现代码,这里就不再赘述,MiddleBlock的代码如下(一切尽在注释中):
class MiddleBlock(Module):
"""
MiddleBlock
这是UNet结构中,连接Encoder和Decoder的最下层部分,
MiddleBlock = ResidualBlock + AttentionBlock + ResidualBlock
"""
def __init__(self, n_channels: int, time_channels: int):
super().__init__()
self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
self.attn = AttentionBlock(n_channels)
self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res1(x, t)
x = self.attn(x)
x = self.res2(x, t)
return x
好了,到目前为止,我们已经将DDPM整体架构的代码解读完毕,接下来,我们动手来看下,如何使用DDPM还原MNIST数据集吧
三、实操:使用扩散模型还原MNIST数据集
在这个Google Colab链接中提供了快速开启DDPM训练的快捷方式,并能从中看到每个epoch训练后,对模型做sampling后的中间结果,方便我们观测模型是如何一步步进行学习的。打开google colab需要翻墙,没有墙的朋友,可以clone github仓库在本地进行测试。
训练前期采样结果(随机抽取16个timestep,向前进行还原):
训练后期采样数据(只有10个epoch),可观测到已初具数字形态: