In this post, we will walk through the key parts of the paper DDPM: Denoising Diffusion Probabilistic Models, together with code.
If you would like to read the full paper, please go to Paper Reading!
1. DDPM Training & Sampling
1.1 Training
Let us briefly walk through how the objective function of DDPM is derived.
NLL(Negative Log Likelihood)
\[\mathbb{E}\left[-\log\ p_\theta(\mathbf{x}_0)\right]\]Using variational inference, minimizing the NLL is converted into the problem of minimizing a negative ELBO term
\[\mathbb{E}\left[-\log\ p_\theta(\mathbf{x}_0)\right] \le \mathbb{E}_q\left[-\log\ \frac{p(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\right] = \mathbb{E}_q\left[-\log\ p(\mathbf{x}_T) -\sum_{t=1}^T\log \frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t|\mathbf{x}_{t-1})}\right] =: L\]Adding $\mathbf{x}_0$ as a condition
\[\begin{align} L &= \mathbb{E}_q\left[-\log\ p(\mathbf{x}_T) -\sum_{t=1}^T\log \frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t|\mathbf{x}_{t-1})}\right] \\ &= \mathbb{E}_q\left[D_{KL}(q(\mathbf{x}_T|\mathbf{x}_0)||p(\mathbf{x}_T)) + \sum_{t=2}^T D_{KL}(q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_{0}) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)) - \log\ p_\theta(\mathbf{x}_{0}|\mathbf{x}_1)\right] \\ &= L_T + L_{1:T-1} + L_0 \end{align}\]Fixing $\beta_t$ to constants leaves $L_T$ with no learnable parameters, which reduces the number of terms that must be trained
\[\begin{align} L &= L_T + L_{1:T-1} + L_0 \\ &\approx L_{1:T-1} + L_0 \end{align}\]Setting the variance $\sigma_t^2$ of $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ in $L_{1:T-1}$ to time-dependent constants that require no training, $\beta_t$ or $\tilde{\beta}_t := \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$, allows a simpler expression
\[\begin{align} L &= L_{1:T-1} + L_0 \\ &= \sum_{t=2}^T \mathbb{E}_q\left[\frac{1}{2\sigma_t^2}\lVert\tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) - \mu_\theta(\mathbf{x}_t, t)\rVert^2\right] + L_0 \end{align}\]The reparameterization trick and a function approximator $\epsilon_\theta$ allow an even simpler expression
\[\begin{align} L &= \sum_{t=2}^T \mathbb{E}_q\left[\frac{1}{2\sigma_t^2}\lVert\tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) - \mu_\theta(\mathbf{x}_t, t)\rVert^2\right] + L_0 \\ &= \mathbb{E}_{\mathbf{x}_0,\ \epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\cdot\alpha_t\cdot(1-\bar{\alpha}_t)}\lVert\epsilon - \epsilon_\theta(\mathbf{x}_t, t)\rVert^2\right] + L_0 \end{align}\]Defining $L_0$ with a discrete decoder and dropping the constant coefficient from $L_{1:T-1}$ lets $L$ be expressed in a single formula
\[\begin{align} L &= \mathbb{E}_{\mathbf{x}_0,\ \epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\cdot\alpha_t\cdot(1-\bar{\alpha}_t)}\lVert\epsilon - \epsilon_\theta(\mathbf{x}_t, t)\rVert^2\right] + L_0 &\\ &\approx \mathbb{E}_{\mathbf{x}_0,\ \epsilon}\left[\lVert\epsilon - \epsilon_\theta(\mathbf{x}_t, t)\rVert^2\right] &\\ &= \mathbb{E}_{\mathbf{x}_0,\ \epsilon}\left[\lVert\epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\cdot\epsilon, t)\rVert^2\right] &\because \mathbf{x}_t \sim q(\mathbf{x}_t|\mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\mathbf{x}_0, (1-\bar{\alpha}_t)\mathrm{I}) \end{align}\]
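The simplified objective maps almost line for line onto code. As a preview, here is a minimal sketch of one training step (Algorithm 1 in the paper), assuming a noise-prediction network eps_model and a precomputed alpha_bar schedule (both names are mine); the full implementation follows in section 1.3.
import torch
import torch.nn.functional as F

def training_step(x0, eps_model, alpha_bar, T):
    # t ~ Uniform({0, ..., T-1}) and eps ~ N(0, I)
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)
    # x_t ~ q(x_t|x_0) via the reparameterization trick
    a_bar = alpha_bar[t].reshape(-1, 1, 1, 1)
    xt = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps
    # ||eps - eps_theta(x_t, t)||^2
    return F.mse_loss(eps, eps_model(xt, t))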
1.2 Sampling
To minimize the fifth step of the training derivation above (the mean-matching form of $L_{1:T-1}$), the mean $\mu_\theta(\mathbf{x}_t, t)$ of $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ must satisfy
\[\mu_\theta(\mathbf{x}_t, t) = \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0)\]Since $\mathbf{x}_t \sim q(\mathbf{x}_t|\mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\mathbf{x}_0, (1-\bar{\alpha}_t)\mathrm{I})$, the reparameterization trick lets us write $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\cdot\epsilon$; solving this for $\mathbf{x}_0$ and substituting into $\tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0)$ gives
\[\begin{align} \mu_\theta(\mathbf{x}_t, t) &= \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) \\ &= \tilde{\mu}_t\left(\mathbf{x}_t, \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\epsilon_\theta(\mathbf{x}_t))\right) \end{align}\]Using Equation 7 of the paper ($\tilde{\mu}_t(\mathbf{x}_{t}, \mathbf{x}_{0}) := \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t$), the expression above simplifies to
\[\begin{align} \mu_\theta(\mathbf{x}_t, t) &= \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) \\ &= \tilde{\mu}_t\left(\mathbf{x}_t, \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\epsilon_\theta(\mathbf{x}_t))\right) \\ &= \frac{1}{\sqrt{\alpha_t}}\Bigg(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\cdot\epsilon_\theta(\mathbf{x}_t, t)\Bigg) \end{align}\]With this mean and the variance fixed in the paper, we can sample. Putting everything together:
\[\begin{align} p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t) &= \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_t, t), \Sigma_\theta(\mathbf{x}_t, t)) \qquad & \\ &= \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_t, t), \sigma^2_t\mathrm{I}) \qquad & \sigma^2_t = \beta_t\ \text{or}\ \sigma^2_t = \tilde{\beta}_t \\ &= \mathcal{N}\Bigg(\mathbf{x}_{t-1}; \frac{1}{\sqrt{\alpha_t}}\bigg(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\cdot\epsilon_\theta(\mathbf{x}_t, t)\bigg), \sigma^2_t\mathrm{I}\Bigg) \qquad & \\ \mathbf{x}_{t-1} &= \frac{1}{\sqrt{\alpha_t}}\bigg(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\cdot\epsilon_\theta(\mathbf{x}_t, t)\bigg) + \sigma_t\cdot \mathbf{z} \quad \text{where}\ \mathbf{z} \sim \mathcal{N}(\mathrm{0}, \mathrm{I}) \qquad & \because \text{Reparameterization trick}\\ \end{align}\]
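Again as a preview of section 1.3, a single reverse step can be sketched as follows, assuming a trained noise predictor eps_model and the schedules alpha, alpha_bar, beta stored as 1-D tensors indexed by an integer timestep t (all names are mine).
import torch

def reverse_step(xt, t, eps_model, alpha, alpha_bar, beta):
    # mu_theta(x_t, t) = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) / sqrt(alpha_t)
    t_batch = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
    eps = eps_model(xt, t_batch)
    mean = (xt - beta[t] / (1 - alpha_bar[t]).sqrt() * eps) / alpha[t].sqrt()
    # sigma_t^2 = beta_t (one of the two choices in the paper); no noise at the last step
    z = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
    return mean + beta[t].sqrt() * z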
1.3 The training objective and the sampling procedure expressed in code
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class DDPM:
    def __init__(self, backbone: nn.Module, T: int, device: torch.device):
        # Paper - 4. Experiments : U-Net backbone
        self.backbone = backbone
        # Paper - 4. Experiments : Set T = 1000
        self.T = T
        # Paper - 2. Background : \beta_t -> reparameterization or constant
        # Paper - 4. Experiments : Forward process variances are set to constants increasing linearly
        self.beta_t = torch.linspace(1e-4, 0.02, T).to(device)
        # Paper - 2. Background : Closed form
        self.alpha_t = 1. - self.beta_t
        self.alpha_bar_t = torch.cumprod(self.alpha_t, dim=0)
        # Paper - 3.2 Reverse process :
        # set \Sigma_\theta(x_t, t) = \sigma^2_t I to untrained time-dependent constants;
        # \sigma^2_t = \beta_t and \sigma^2_t = \tilde{\beta}_t had similar results
        self.sigma2t = self.beta_t

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q(x_t|x_0) = N(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t)I)
        """
        # \bar{\alpha}_t for each given t
        alpha_bar_t = self.alpha_bar_t.gather(-1, t).reshape(-1, 1, 1, 1)
        # Mean of q(x_t|x_0) = \sqrt{\bar{\alpha}_t} * x0
        mean = (alpha_bar_t ** 0.5) * x0
        # Variance of q(x_t|x_0) = 1 - \bar{\alpha}_t
        var = 1 - alpha_bar_t
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, epsilon: Optional[torch.Tensor] = None):
        """
        Forward process (diffusion process)
        q(x_t|x_0) = N(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t)I)
        x_t = mean + \sqrt{var} * \epsilon (\epsilon ~ N(0, I))
            = \sqrt{\bar{\alpha}_t} x_0 + (1 - \bar{\alpha}_t) ** 0.5 * \epsilon
        """
        if epsilon is None:
            epsilon = torch.randn_like(x0)
        mean, var = self.q_xt_x0(x0, t)
        xt = mean + (var ** 0.5) * epsilon
        return xt

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """
        Reverse process (denoising process)
        x_{t-1} ~ p_\theta(x_{t-1}|x_t)
        p_\theta(x_{t-1}|x_t) := N(x_{t-1}; \mu_\theta, \Sigma_\theta)
                              := N(x_{t-1}; \mu_\theta, \beta_t)
                              := N(x_{t-1}; 1/\sqrt{\alpha_t}(x_t - (\beta_t/\sqrt{1 - \bar{\alpha}_t})\epsilon_\theta(x_t, t)), \beta_t)
        x_{t-1} = \mu_\theta + \sqrt{\Sigma_\theta} * z (z ~ N(0, I))
                = 1/\sqrt{\alpha_t}(x_t - (\beta_t/\sqrt{1 - \bar{\alpha}_t})\epsilon_\theta(x_t, t)) + \sqrt{\beta_t} * z
                = denoise_xt
        """
        epsilon_theta = self.backbone(xt, t)
        # \bar{\alpha}_t for each given t
        alpha_bar_t = self.alpha_bar_t.gather(-1, t).reshape(-1, 1, 1, 1)
        # \beta_t for each given t
        beta_t = self.beta_t.gather(-1, t).reshape(-1, 1, 1, 1)
        # \alpha_t for each given t
        alpha_t = 1 - beta_t
        # \mu_\theta
        mean = 1 / (alpha_t ** 0.5) * (xt - (beta_t / (1 - alpha_bar_t) ** 0.5) * epsilon_theta)
        # \Sigma_\theta = \sigma^2_t = \beta_t
        var = beta_t
        # z ~ N(0, I); Algorithm 2 in the paper adds no noise at the final step (t = 0 here)
        z = torch.randn_like(xt) if int(t[0]) > 0 else torch.zeros_like(xt)
        # x_{t-1} = \mu_\theta + \sqrt{\Sigma_\theta} * z
        denoise_xt = mean + (var ** 0.5) * z
        return denoise_xt

    def loss_simple(self, x0: torch.Tensor, epsilon: Optional[torch.Tensor] = None):
        """
        L_simple = E[||\epsilon - \epsilon_\theta(x_t, t)||^2] with x_t ~ q(x_t|x_0)
        Trains the function approximator \epsilon_\theta against the Gaussian noise \epsilon
        """
        batch_size = x0.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=x0.device, dtype=torch.long)
        if epsilon is None:
            epsilon = torch.randn_like(x0)
        xt = self.q_sample(x0, t, epsilon=epsilon)
        epsilon_theta = self.backbone(xt, t)
        return F.mse_loss(epsilon, epsilon_theta)
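The paper reports training with Adam at a learning rate of 2e-4 on CIFAR10. Assuming a dataloader that yields batches of images scaled to [-1, 1], and the backbone and ddpm instances defined later in section 3, a minimal training loop over loss_simple might look like this sketch:
# Hypothetical training loop; n_epochs and dataloader are assumed to be defined
optimizer = torch.optim.Adam(backbone.parameters(), lr=2e-4)
for epoch in range(n_epochs):
    for x0, _ in dataloader:
        x0 = x0.to(device)
        loss = ddpm.loss_simple(x0)  # L_simple from section 1.1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()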
2. The DDPM backbone (U-Net)
DDPM uses a U-Net backbone similar to an unmasked PixelCNN++, with group normalization throughout.
The sinusoidal positional encoding from the Transformer is used for the time embedding, and self-attention is used for the convolutional blocks inside the $16 \times 16$ resolution. In other words, self-attention is applied at the $8 \times 8$ and $4 \times 4$ resolutions.
2.1 The U-Net as a diagram
First, the structure of the model is drawn below. (see the code first)
※ Up and Down in the code below denote sets of UpBlocks and DownBlocks respectively, each corresponding to one resolution level of the U-Net.
※ The colored parts of the diagram indicate where feature maps are concatenated in the U-Net structure.
2.2 U-Net Code
The U-Net structure examined in the diagram above can be written in code as follows. (back to the diagram)
※ Up and Down in the code denote sets of UpBlocks and DownBlocks respectively, each corresponding to one resolution level. (They are not marked in the diagram.)
※ Reading the comments alongside lets you compare the code with how it is described in the paper.
2.2.1 Time Embedding
class TimeEmbedding(nn.Module):
    def __init__(self, n_channels: int):
        """
        Appendix B. Experiment details about residual blocks
        - Diffusion time t is specified by adding the Transformer sinusoidal position embedding into each residual block.
        """
        super().__init__()
        self.n_channels = n_channels
        self.embed1 = nn.Linear(self.n_channels // 4, self.n_channels)
        self.silu = nn.SiLU()
        self.embed2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        # Transformer sinusoidal position embedding (sin/cos halves -> n_channels // 4 dims)
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        # Two-layer MLP on top of the sinusoidal embedding
        t = self.embed1(emb)
        t = self.silu(t)
        t = self.embed2(t)
        return t
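As a quick sanity check on the shapes (the sizes here are hypothetical), a batch of integer timesteps is mapped to an n_channels-dimensional embedding per sample:
time_embedding = TimeEmbedding(n_channels=256)  # the time_channels used by the UNet below
t = torch.randint(0, 1000, (4,))                # a batch of 4 timesteps
print(time_embedding(t).shape)                  # torch.Size([4, 256])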
2.2.2 Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int, dropout_rate: float):
        """
        Appendix B. Experiment details about residual blocks
        - Wide ResNet
        - We replaced weight normalization with group normalization to make the implementation simpler
        - Diffusion time t is specified by adding the Transformer sinusoidal position embedding into each residual block
        - We set the dropout rate on CIFAR10 to 0.1 by sweeping over the values {0.1, 0.2, 0.3, 0.4}
        """
        super().__init__()
        # Project the time embedding to match the current number of features
        self.time_proj = nn.Linear(time_channels, out_channels)
        self.time_silu = nn.SiLU()
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.silu1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.silu2 = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        residual = x
        x = self.norm1(x)
        x = self.silu1(x)
        x = self.conv1(x)
        # Add the time embedding, broadcast over the spatial dimensions
        t = self.time_silu(t)
        t = self.time_proj(t)
        t = t[:, :, None, None]
        x = x + t
        x = self.norm2(x)
        x = self.silu2(x)
        x = self.dropout(x)
        x = self.conv2(x)
        # Skip connection (1x1 conv if the channel counts differ)
        residual = self.shortcut(residual)
        x = x + residual
        return x
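A quick shape check with hypothetical sizes shows how the time embedding is projected to out_channels and broadcast over the spatial dimensions:
block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256, n_groups=32, dropout_rate=0.1)
x = torch.randn(4, 64, 32, 32)
t_emb = torch.randn(4, 256)   # stand-in for a TimeEmbedding output
print(block(x, t_emb).shape)  # torch.Size([4, 128, 32, 32])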
2.2.3 Self Attention Block
class SelfAttentionBlock(nn.Module):
    def __init__(self, n_channels: int, n_heads: int, n_groups: int):
        """
        Appendix B. Experiment details about residual blocks
        - All models have self-attention blocks at the 16 x 16 resolution between the convolutional blocks
        - We replaced weight normalization with group normalization to make the implementation simpler.
        """
        super().__init__()
        self.n_heads = n_heads
        self.d_k = n_channels
        self.proj = nn.Linear(n_channels, n_heads * self.d_k * 3)  # W_q, W_k, W_v
        self.output = nn.Linear(n_heads * self.d_k, n_channels)    # W_o
        self.norm = nn.GroupNorm(n_groups, n_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # t is not used here; it is kept so every block shares the same interface
        batch_size, n_channels, height, width = x.shape
        # Normalization
        x = self.norm(x)
        # Flatten the spatial dimensions and project to query, key, value
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        proj_x = self.proj(x)
        qkv = proj_x.view(batch_size, -1, self.n_heads, self.d_k * 3)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Multi-head attention
        attention_score = torch.einsum("bihd, bjhd -> bijh", q, k) * (self.d_k ** -0.5)  # QK^T / \sqrt{d_k}
        attention_score = attention_score.softmax(dim=2)                                 # Softmax(QK^T / \sqrt{d_k})
        attention = torch.einsum("bijh, bjhd -> bihd", attention_score, v)               # Softmax(QK^T / \sqrt{d_k}) * V
        attention = attention.view(batch_size, -1, self.n_heads * self.d_k)
        attention_output = self.output(attention)
        # Skip connection
        x = x + attention_output
        # Recover the original shape
        x = x.permute(0, 2, 1).view(batch_size, n_channels, height, width)
        return x
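The block attends over the flattened spatial positions, so the output shape matches the input. A quick check with hypothetical sizes:
attn = SelfAttentionBlock(n_channels=256, n_heads=1, n_groups=32)
x = torch.randn(4, 256, 8, 8)  # an 8 x 8 feature map -> 64 attention positions
t_emb = torch.randn(4, 256)    # unused by this block
print(attn(x, t_emb).shape)    # torch.Size([4, 256, 8, 8])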
2.2.4 DownBlock
class DownBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, has_attention: bool,
                 time_channels: int, n_groups: int, dropout_rate: float, n_heads: int):
        """
        Appendix B. Experiment details about residual blocks
        - All models have self-attention blocks at the 16 x 16 resolution between the convolutional blocks
        """
        super().__init__()
        self.has_attention = has_attention
        self.residual_block = ResidualBlock(in_channels, out_channels, time_channels, n_groups, dropout_rate)
        self.attn = SelfAttentionBlock(out_channels, n_heads, n_groups)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.residual_block(x, t)
        if self.has_attention:
            x = self.attn(x, t)
        return x
2.2.5 Down
class Down(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, has_attention: bool,
                 time_channels: int, n_groups: int, dropout_rate: float, n_heads: int, is_down: bool):
        """
        Appendix B. Experiment details about residual blocks
        - All models have two convolutional residual blocks per resolution level
        """
        super().__init__()
        self.is_down = is_down
        self.down_sample = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1)
        self.down_block1 = DownBlock(in_channels, out_channels, has_attention, time_channels, n_groups, dropout_rate, n_heads)
        self.down_block2 = DownBlock(out_channels, out_channels, has_attention, time_channels, n_groups, dropout_rate, n_heads)

    def forward(self, x, t):
        result = []  # Feature maps kept for the skip connections
        if self.is_down:
            x = self.down_sample(x)
        result.append(x)
        x = self.down_block1(x, t)
        result.append(x)
        x = self.down_block2(x, t)
        result.append(x)
        return x, result
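The result list is what the up path later consumes as skip connections. A hypothetical trace for one downsampling stage (the sizes match down_2 in the UNet below):
down = Down(128, 256, has_attention=True, time_channels=256,
            n_groups=32, dropout_rate=0.1, n_heads=1, is_down=True)
x = torch.randn(4, 128, 16, 16)
t_emb = torch.randn(4, 256)
x, result = down(x, t_emb)
print(x.shape)  # torch.Size([4, 256, 8, 8])
# result holds (4, 128, 8, 8), (4, 256, 8, 8), (4, 256, 8, 8) for the skip connections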
2.2.6 Middle
class Middle(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int, dropout_rate: float, n_heads: int):
        super().__init__()
        self.residual_block1 = ResidualBlock(in_channels, out_channels, time_channels, n_groups, dropout_rate)
        self.attn = SelfAttentionBlock(out_channels, n_heads, n_groups)
        self.residual_block2 = ResidualBlock(out_channels, out_channels, time_channels, n_groups, dropout_rate)

    def forward(self, x, t):
        x = self.residual_block1(x, t)
        x = self.attn(x, t)
        x = self.residual_block2(x, t)
        return x
2.2.7 UpBlock
class UpBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, has_attention: bool, time_channels: int,
                 n_groups: int, dropout_rate: float, n_heads: int):
        """
        Appendix B. Experiment details about residual blocks
        - All models have self-attention blocks at the 16 x 16 resolution between the convolutional blocks
        """
        super().__init__()
        self.has_attention = has_attention
        # in_channels + out_channels : the input arrives concatenated with a skip connection from the down path
        self.residual_block = ResidualBlock(in_channels + out_channels, out_channels, time_channels, n_groups, dropout_rate)
        self.attn = SelfAttentionBlock(out_channels, n_heads, n_groups)

    def forward(self, x, t):
        x = self.residual_block(x, t)
        if self.has_attention:
            x = self.attn(x, t)
        return x
2.2.8 Up
class Up(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, has_attention: bool,
                 time_channels: int, n_groups: int, dropout_rate: float, n_heads: int, is_up: bool):
        """
        Appendix B. Experiment details about residual blocks
        - All models have two convolutional residual blocks per resolution level
        """
        super().__init__()
        self.is_up = is_up
        self.up_sample = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.up_block1 = UpBlock(in_channels, in_channels, has_attention, time_channels, n_groups, dropout_rate, n_heads)   # concatenated with an output of down_{3-i}
        self.up_block2 = UpBlock(in_channels, in_channels, has_attention, time_channels, n_groups, dropout_rate, n_heads)   # concatenated with an output of down_{3-i}
        self.up_block3 = UpBlock(in_channels, out_channels, has_attention, time_channels, n_groups, dropout_rate, n_heads)  # concatenated with an output of down_{3-i}

    def forward(self, x, t, down):
        x = torch.cat((x, down[-1]), dim=1)
        x = self.up_block1(x, t)
        x = torch.cat((x, down[-2]), dim=1)
        x = self.up_block2(x, t)
        x = torch.cat((x, down[-3]), dim=1)
        x = self.up_block3(x, t)
        if self.is_up:
            x = self.up_sample(x)
        return x
2.2.9 U-Net
class UNet(nn.Module):
    def __init__(self, img_channels: int = 3, proj_channels: int = 64,
                 n_groups: int = 32, dropout_rate: float = 0.1, n_heads: int = 1,
                 has_down_attention: List[bool] = [False, False, True, True],
                 has_up_attention: List[bool] = [True, True, False, False],
                 is_down: List[bool] = [False, True, True, True],
                 is_up: List[bool] = [True, True, True, False]):
        """
        Appendix B. Our neural network architecture follows the backbone of PixelCNN++, which is a U-Net based on a Wide ResNet.
        """
        super().__init__()
        # Four parameters are used here for readability, but since the U-Net is symmetric, two would suffice
        assert len(has_down_attention) == len(has_up_attention) == len(is_down) == len(is_up)
        assert has_down_attention == has_up_attention[::-1]
        assert is_down == is_up[::-1]
        # Setting
        time_channels = proj_channels * 4
        # Time embedding
        self.time_embedding = TimeEmbedding(time_channels)
        # Image projection 1
        self.proj1 = nn.Conv2d(img_channels, proj_channels, kernel_size=3, stride=1, padding=1)  # Channels : 3 -> 64
        # Down
        self.down_0 = Down(64, 64, has_down_attention[0], time_channels, n_groups, dropout_rate, n_heads, is_down[0])     # (-1, 64, 32, 32) -> (-1, 64, 32, 32)
        self.down_1 = Down(64, 128, has_down_attention[1], time_channels, n_groups, dropout_rate, n_heads, is_down[1])    # (-1, 64, 32, 32) -> (-1, 128, 16, 16)
        self.down_2 = Down(128, 256, has_down_attention[2], time_channels, n_groups, dropout_rate, n_heads, is_down[2])   # (-1, 128, 16, 16) -> (-1, 256, 8, 8)
        self.down_3 = Down(256, 1024, has_down_attention[3], time_channels, n_groups, dropout_rate, n_heads, is_down[3])  # (-1, 256, 8, 8) -> (-1, 1024, 4, 4)
        # Middle
        self.middle = Middle(1024, 1024, time_channels, n_groups, dropout_rate, n_heads)                                  # (-1, 1024, 4, 4) -> (-1, 1024, 4, 4)
        # Up
        self.up_0 = Up(1024, 256, has_up_attention[0], time_channels, n_groups, dropout_rate, n_heads, is_up[0])          # (-1, 1024, 4, 4) + res3 -> (-1, 256, 8, 8)
        self.up_1 = Up(256, 128, has_up_attention[1], time_channels, n_groups, dropout_rate, n_heads, is_up[1])           # (-1, 256, 8, 8) + res2 -> (-1, 128, 16, 16)
        self.up_2 = Up(128, 64, has_up_attention[2], time_channels, n_groups, dropout_rate, n_heads, is_up[2])            # (-1, 128, 16, 16) + res1 -> (-1, 64, 32, 32)
        self.up_3 = Up(64, 64, has_up_attention[3], time_channels, n_groups, dropout_rate, n_heads, is_up[3])             # (-1, 64, 32, 32) + res0 -> (-1, 64, 32, 32)
        # Image projection 2
        self.proj2 = nn.Conv2d(proj_channels, img_channels, kernel_size=3, stride=1, padding=1)  # Channels : 64 -> 3

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # Time embedding
        t = self.time_embedding(t)
        # Image projection 1
        x = self.proj1(x)
        # Down
        x, res0 = self.down_0(x, t)
        x, res1 = self.down_1(x, t)
        x, res2 = self.down_2(x, t)
        x, res3 = self.down_3(x, t)
        # Middle
        x = self.middle(x, t)
        # Up
        x = self.up_0(x, t, res3)
        x = self.up_1(x, t, res2)
        x = self.up_2(x, t, res1)
        x = self.up_3(x, t, res0)
        # Image projection 2
        x = self.proj2(x)
        return x
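As an end-to-end sanity check, the noise predictor must return an output with the same shape as its image input:
unet = UNet()
x = torch.randn(2, 3, 32, 32)
t = torch.randint(0, 1000, (2,))
print(unet(x, t).shape)  # torch.Size([2, 3, 32, 32])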
3. DDPM Inference
You can observe the generated images by running p_sample, the sampling function defined in the DDPM class, one step at a time.
First, define the UNet and DDPM from above, together with the parameters they need.
# Time schedule T
T = 1000
# Device setting
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# U-Net backbone of the DDPM
backbone = UNet().to(device)
# DDPM
ddpm = DDPM(backbone=backbone, T=T, device=device)
Running the sampling code below with the model defined above produces the generated images.
# Number of samples
n_samples = 10
# Generate Gaussian noises
xt = torch.randn([n_samples, 3, 32, 32], device=device)
# Generate images (no gradients are needed at inference time)
with torch.no_grad():
    for t in reversed(range(T)):
        xt = ddpm.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))
# Result
result = xt
4. DDPM Interpolation
Perhaps the greatest strength of generative models is that they can produce images that do not exist in the dataset. To create images outside the dataset distribution, we can use interpolation.
Let two images in the dataset be $x^1_0$ and $x^2_0$, and let their latents in the diffusion space be $x^1_t$ and $x^2_t$. Interpolation then proceeds as follows.
First, send the two images $x^1_0$ and $x^2_0$ into the diffusion space.
\[x^1_t \sim q(x^1_t|x^1_0)\qquad \&\qquad x^2_t \sim q(x^2_t|x^2_0)\]Perform the interpolation in the diffusion space.
\[\bar{x}_t = (1 - \lambda)x^1_t + \lambda x^2_t\]Send the interpolated latent back to the image space to complete the interpolation.
\[\bar{x}_0 \sim p_\theta(\bar{x}_0|\bar{x}_t)\]
The result depends on the value chosen for $t$: the larger $t$ is, the more varied the interpolation becomes, and the more novel results you can observe.
※ The variables defined in the Inference section are reused as-is.
# Stand-ins for two images from the dataset
x1_0 = torch.randn([n_samples, 3, 32, 32], device=device)
x2_0 = torch.randn([n_samples, 3, 32, 32], device=device)
# Setting interpolation_t (Paper - Fig 8.)
interpolation_t = 500
t = torch.full((n_samples,), interpolation_t, device=device, dtype=torch.long)
# Setting lambda
lambda_ = .5
# Interpolation formula (xt: \bar{x}_t)
xt = (1 - lambda_) * ddpm.q_sample(x1_0, t) + lambda_ * ddpm.q_sample(x2_0, t)
# Generate interpolated images
for t_ in reversed(range(interpolation_t)):
    xt = ddpm.p_sample(xt, xt.new_full((n_samples,), t_, dtype=torch.long))
# Result
result = xt
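Since the result depends on $t$, a simple sweep over several values of interpolation_t (the values here are arbitrary) illustrates the trade-off: small $t$ stays close to the two source images, while large $t$ destroys more of their information and yields more novel mixtures.
for interpolation_t in [100, 300, 500, 700]:
    t = torch.full((n_samples,), interpolation_t, device=device, dtype=torch.long)
    xt = (1 - lambda_) * ddpm.q_sample(x1_0, t) + lambda_ * ddpm.q_sample(x2_0, t)
    for t_ in reversed(range(interpolation_t)):
        xt = ddpm.p_sample(xt, xt.new_full((n_samples,), t_, dtype=torch.long))
    # Inspect xt here, e.g. collect it per interpolation_t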
5. Conclusion
By parameterizing with Gaussian distributions and fixing the time schedule $\beta_t$ to constants, training becomes simpler, and since the sampling procedure resembles denoising score matching with Langevin dynamics, high-quality samples can be generated.
The paper introduces many concepts and can feel complicated, but working through them one by one makes for good study.
One remaining question is whether sampling could be simplified further. A great deal of research has reportedly been done on this topic, so I plan to study it in the future.