今回は PyTorch を用いて, Ho et al. [NeurIPS 33(2020)] による DDPM (Denoising Diffusion Probabilistic Model) の実装の概要を見る.DDPM は拡散模型の最初の例の1つであり,ノイズからデータ分布まで到達するフローを定める拡散過程(雑音除去過程)を,データをノイズにする拡散過程の時間反転として学習する方法である.画像や動画だけでなく,離散空間上でタンパク質などの構造生成でも state of the art の性能を示すモデルである.
A Blog Entry on Bayesian Computation by an Applied Mathematician
class ConvBlock(nn.Conv2d):
    """2-D convolution with optional GroupNorm, SiLU and residual connection.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size (an int when ``padding='same'``,
            because the padding is derived from it arithmetically).
        activation_fn: Flag; any truthy value appends ``nn.SiLU()``. The value
            itself is not used as the activation.
        drop_rate: Accepted for interface compatibility; not used by this block.
        stride, dilation, groups, bias: Forwarded to ``nn.Conv2d``.
        padding: ``'same'`` is converted to ``kernel_size // 2 * dilation``;
            any other value is forwarded to ``nn.Conv2d`` unchanged.
        gn: If true, apply ``nn.GroupNorm(gn_groups, out_channels)`` after the
            convolution.
        gn_groups: Number of groups for the GroupNorm layer.

    Shapes:
        x: (N, C_in, H, W)
        y: (N, C_out, H, W)
    """

    def __init__(self, in_channels, out_channels, kernel_size, activation_fn=None,
                 drop_rate=0., stride=1, padding='same', dilation=1, groups=1,
                 bias=True, gn=False, gn_groups=8):
        if padding == 'same':
            # Explicit symmetric padding that keeps H and W for odd kernels
            # at stride 1.
            padding = kernel_size // 2 * dilation

        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias)

        self.activation_fn = nn.SiLU() if activation_fn else None
        self.group_norm = nn.GroupNorm(gn_groups, out_channels) if gn else None

    def forward(self, x, time_embedding=None, residual=False):
        """Apply the convolution, then optional GroupNorm and SiLU.

        Args:
            x: Input tensor.
            time_embedding: Optional tensor added to ``x`` before the
                convolution. Only consulted when ``residual`` is true.
            residual: If true, the (embedding-shifted) input is added to the
                convolution output, which requires matching input and output
                shapes.

        Returns:
            The processed tensor.
        """
        if residual:
            # In the paper, the diffusion timestep embedding is only applied
            # to the residual blocks of the U-Net.
            if time_embedding is not None:
                x = x + time_embedding
            # The skip path carries the embedding-shifted input.
            y = x + super().forward(x)
        else:
            y = super().forward(x)

        if self.group_norm is not None:
            y = self.group_norm(y)
        if self.activation_fn is not None:
            y = self.activation_fn(y)

        return y
class Diffusion(nn.Module):
    """DDPM wrapper around a noise-prediction network.

    Holds a fixed linear beta schedule and provides the forward (noising)
    process used for training and the step-by-step reverse process used for
    sampling.

    Args:
        model: Noise-prediction network, called as ``model(x_t, t)``.
        image_resolution: ``(H, W, C)``; used to shape the initial noise in
            :meth:`sample`.
        n_times: Number of diffusion steps.
        beta_minmax: ``(beta_1, beta_T)`` end points of the linear schedule.
        device: Device on which the schedule tensors and sampled noise are
            placed. The schedule tensors are plain attributes (not registered
            buffers), so they stay on this device regardless of ``.to()``.
    """

    def __init__(self, model, image_resolution=(32, 32, 3), n_times=1000,
                 beta_minmax=(1e-4, 2e-2), device='cuda'):
        super().__init__()

        self.n_times = n_times
        self.img_H, self.img_W, self.img_C = image_resolution

        self.model = model

        # Linear variance schedule (betas), following the DDPM paper.
        beta_1, beta_T = beta_minmax
        betas = torch.linspace(start=beta_1, end=beta_T, steps=n_times).to(device)
        self.sqrt_betas = torch.sqrt(betas)

        # Alphas for the forward diffusion kernel.
        self.alphas = 1 - betas
        self.sqrt_alphas = torch.sqrt(self.alphas)
        alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - alpha_bars)
        self.sqrt_alpha_bars = torch.sqrt(alpha_bars)

        self.device = device

    def extract(self, a, t, x_shape):
        """Gather ``a[t]`` per batch element, reshaped to broadcast over ``x_shape``.

        from lucidrains' implementation
        https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L376

        Returns:
            Tensor of shape ``(b, 1, ..., 1)`` with ``len(x_shape) - 1``
            trailing singleton dimensions, where ``b = t.shape[0]``.
        """
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def scale_to_minus_one_to_one(self, x):
        """Return ``x * 2 - 1`` (maps [0, 1] onto [-1, 1])."""
        # According to the DDPM paper, normalization seems to be crucial to
        # train the reverse-process network.
        return x * 2 - 1

    def reverse_scale_to_zero_to_one(self, x):
        """Return ``(x + 1) * 0.5`` (maps [-1, 1] onto [0, 1])."""
        return (x + 1) * 0.5

    def make_noisy(self, x_zeros, t):
        """Perturb ``x_0`` into ``x_t`` with the forward diffusion kernel.

        Returns:
            ``(noisy_sample, epsilon)`` where ``noisy_sample`` is detached and
            ``epsilon`` is the Gaussian noise that was mixed in.
        """
        epsilon = torch.randn_like(x_zeros).to(self.device)

        sqrt_alpha_bar = self.extract(self.sqrt_alpha_bars, t, x_zeros.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, t, x_zeros.shape)

        # Forward process with the fixed variance schedule:
        #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
        noisy_sample = x_zeros * sqrt_alpha_bar + epsilon * sqrt_one_minus_alpha_bar

        return noisy_sample.detach(), epsilon

    def forward(self, x_zeros):
        """Noise a batch at random timesteps and predict the noise.

        Args:
            x_zeros: 4-D batch of images; rescaled with ``x * 2 - 1`` first.

        Returns:
            ``(perturbed_images, epsilon, pred_epsilon)``.
        """
        x_zeros = self.scale_to_minus_one_to_one(x_zeros)

        B, _, _, _ = x_zeros.shape

        # (1) randomly choose a diffusion time-step per sample
        t = torch.randint(low=0, high=self.n_times, size=(B,)).long().to(self.device)

        # (2) forward diffusion process: perturb x_zeros with the fixed
        #     variance schedule
        perturbed_images, epsilon = self.make_noisy(x_zeros, t)

        # (3) predict epsilon (noise) given the perturbed data at time-step t
        pred_epsilon = self.model(perturbed_images, t)

        return perturbed_images, epsilon, pred_epsilon

    def denoise_at_t(self, x_t, timestep, t):
        """Run one reverse step on ``x_t``.

        Args:
            x_t: Current noisy sample.
            timestep: Long tensor of schedule indices, one per batch element.
            t: The same index as a Python int; decides whether fresh noise is
                added.

        Returns:
            The sample after this reverse step, clamped to [-1, 1].
        """
        # `t` is a 0-based schedule index (sample() counts n_times-1 ... 0),
        # so only the final step, t == 0, is noise-free. This corresponds to
        # "z = 0 if t = 1" in the paper's 1-based Algorithm 2.
        if t > 0:
            z = torch.randn_like(x_t).to(self.device)
        else:
            z = torch.zeros_like(x_t).to(self.device)

        # At inference, the predicted noise (epsilon) is used to restore the
        # perturbed data sample.
        epsilon_pred = self.model(x_t, timestep)

        alpha = self.extract(self.alphas, timestep, x_t.shape)
        sqrt_alpha = self.extract(self.sqrt_alphas, timestep, x_t.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars, timestep, x_t.shape)
        sqrt_beta = self.extract(self.sqrt_betas, timestep, x_t.shape)

        # Denoise at time t, utilizing the predicted noise.
        x_t_minus_1 = 1 / sqrt_alpha * (x_t - (1 - alpha) / sqrt_one_minus_alpha_bar * epsilon_pred) + sqrt_beta * z

        return x_t_minus_1.clamp(-1., 1)

    def sample(self, N):
        """Generate ``N`` images by iterating the reverse process from pure noise.

        Returns:
            Tensor of shape ``(N, C, H, W)`` rescaled with ``(x + 1) * 0.5``.
        """
        # Start from a random noise vector x_T (kept in the variable `x_t`,
        # which is overwritten at every step).
        x_t = torch.randn((N, self.img_C, self.img_H, self.img_W)).to(self.device)

        # Denoise step by step from x_T down to x_0.
        for t in range(self.n_times - 1, -1, -1):
            timestep = torch.tensor([t]).repeat_interleave(N, dim=0).long().to(self.device)
            x_t = self.denoise_at_t(x_t, timestep, t)

        # Denormalize x_0 into 0 ~ 1 ranged values.
        x_0 = self.reverse_scale_to_zero_to_one(x_t)

        return x_0


diffusion = Diffusion(model, image_resolution=img_size, n_times=n_timesteps,
                      beta_minmax=beta_minmax, device=DEVICE).to(DEVICE)

optimizer = Adam(diffusion.parameters(), lr=lr)
denoising_loss = nn.MSELoss()
2.5 エンコーディングの様子
def count_parameters(model) -> int:
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print("Number of model parameters: ", count_parameters(diffusion))
/var/folders/gx/6w78f6997l5___173r25fp3m0000gn/T/ipykernel_6810/3728020930.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
# Load previously generated samples onto the CPU.
# weights_only=True restricts unpickling to tensors and other allowlisted
# primitive types, which avoids arbitrary code execution from the pickle and
# silences the FutureWarning about the changing default.
# NOTE(review): assumes the file stores plain tensors (or containers of
# tensors) — confirm against the code that saved it.
generated_images = torch.load(
    "Files/generated_images1.pt",
    map_location=torch.device('cpu'),
    weights_only=True,
)
Ronneberger, O., Fischer, P., and Brox, T. (2015). U-net: Convolutional networks for biomedical image segmentation. In N. Navab, J. Hornegger, W. M. Wells, and A. F. Frangi, editors, Medical image computing and computer-assisted intervention – MICCAI 2015, pages 234–241. Cham: Springer International Publishing.
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … Polosukhin, I. (2017). Attention is all you need. In Advances in neural information processing systems, Vol. 30.