import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
import glob
from torchvision.transforms import ToTensor, PILToTensor
import os
import pandas as pd
from torchvision.io import read_image
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torch.nn.functional as F
import time
from pathlib import Path
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm
import matplotlib.colors as mcolors
import numpy as np
from torchinfo import summary
Variational Autoencoder
Python
Deep Learning
NLP
Part 1 of 3 in a series exploring generative models.
# Kaggle API credentials.
# Never commit the API key to source: it is read from the KAGGLE_USERNAME and
# KAGGLE_KEY environment variables and written to ~/.kaggle/kaggle.json only
# if that file does not exist yet.  (Any key that was previously hard-coded
# here should be treated as leaked and revoked.)
import json

cred_path = Path('~/.kaggle/kaggle.json').expanduser()
if not cred_path.exists():
    kaggle_username = os.environ.get('KAGGLE_USERNAME')
    kaggle_key = os.environ.get('KAGGLE_KEY')
    if kaggle_username and kaggle_key:
        creds = json.dumps({'username': kaggle_username, 'key': kaggle_key})
        cred_path.parent.mkdir(parents=True, exist_ok=True)
        cred_path.write_text(creds)
        # Owner read/write only: the file contains a secret.
        cred_path.chmod(0o600)

import kaggle
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
Load and prepare data
# Preprocessing pipeline applied to every image on load.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),  # single channel
        transforms.Resize(64),                        # shorter side -> 64 px
        transforms.CenterCrop(64),                    # central 64 x 64 window
        transforms.ToTensor(),                        # PIL image -> tensor
    ]
)

# ImageFolder treats each sub-directory of 'shapes' as one class label.
dataset = ImageFolder(root='shapes', transform=transform)

# Sanity check: shape of the first image tensor, (channels, height, width).
dataset[0][0].shape
torch.Size([1, 64, 64])
# Display the first sample; dropping the leading channel axis leaves (H, W).
plt.imshow(dataset[0][0].squeeze(dim=0))
<matplotlib.image.AxesImage at 0x7fa2724ef670>
# Mini-batch size, also used by the torchinfo summary of the autoencoder.
batch_size = 32
# shuffle=True: the data is reshuffled each time the loader is iterated.
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
def show_images(images, nmax=64):
    """Plot up to ``nmax`` images of a (N, C, H, W) tensor as an 8-wide grid.

    The tensor is detached and moved to the CPU first, so batches that live on
    the GPU can be plotted as well.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(images.detach()[:nmax].cpu(), nrow=8)
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0))


def show_batch(dl, nmax=64):
    """Plot the images of the first batch yielded by ``dl``.

    Each batch is indexed with ``[0]`` to get the image tensor; the remaining
    items (e.g. labels) are ignored.
    """
    for batch in dl:
        show_images(batch[0], nmax)
        break


show_batch(dataloader)
Build Autoencoder
class Encoder(nn.Module):
    """Convolutional encoder mapping a 1x64x64 image to a 2-D latent code.

    Six Conv2d -> BatchNorm2d -> LeakyReLU -> Dropout2d stages (three of them
    with stride 2, so 64 -> 32 -> 16 -> 8), then Flatten and a Linear layer
    projecting the 64 * 8 * 8 = 4096 features down to 2 values.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) of each convolutional stage.
        stages = (
            (1, 32, 2),   # 64 -> 32
            (32, 32, 1),
            (32, 32, 1),
            (32, 64, 2),  # 32 -> 16
            (64, 64, 1),
            (64, 64, 2),  # 16 -> 8
        )
        layers = []
        for in_ch, out_ch, stride in stages:
            layers += [
                nn.Conv2d(in_ch, out_ch, stride=stride, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Dropout2d(0.25),
            ]
        layers += [nn.Flatten(), nn.Linear(4096, 2)]
        # One flat Sequential, so parameter names (encoder.0.weight, ...)
        # match a hand-written listing of the same layers.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)
# Smoke test: run one real batch through a freshly initialised encoder and
# compare the input and latent shapes.
X = next(iter(dataloader))[0]   # image tensor of the first batch
enc = Encoder()
z = enc(X)
(X.shape, z.shape)
(torch.Size([32, 1, 64, 64]), torch.Size([32, 2]))
class Reshape(nn.Module):
    """Layer form of ``Tensor.view`` with a fixed target shape.

    Example: ``Reshape(-1, 64, 8, 8)`` turns an (N, 4096) tensor into
    (N, 64, 8, 8).
    """

    def __init__(self, *args):
        super().__init__()
        # Target shape, stored as the tuple of positional arguments.
        self.shape = args

    def forward(self, x):
        target_shape = self.shape
        return x.view(target_shape)
class Trim(nn.Module):
    """Crop the last two (spatial) dims of a 4-D tensor to the top-left window.

    The output is at most ``size`` x ``size`` (default 64). The tensor is
    returned unchanged if it is already that small or smaller.
    """

    def __init__(self, *args, size=64):
        super().__init__()
        # Positional *args are accepted and ignored, for backward compatibility.
        self.size = size

    def forward(self, x):
        return x[:, :, :self.size, :self.size]
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a 2-D latent code to a 1x64x64 image.

    A Linear + LeakyReLU projection produces 4096 = 64 * 8 * 8 features. They
    are reshaped to (N, 64, 8, 8) and up-sampled 8 -> 16 -> 32 -> 64. The final
    Sigmoid keeps pixel values in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(2, 4096),
            nn.LeakyReLU(0.01),
        )

        # (in_channels, out_channels, stride); stride-2 stages double H and W,
        # which needs output_padding=1 (stride - 1) to land on even sizes.
        stages = (
            (64, 64, 2),  # 8 -> 16
            (64, 64, 1),
            (64, 32, 2),  # 16 -> 32
            (32, 32, 1),
            (32, 32, 1),
        )
        layers = []
        for in_ch, out_ch, stride in stages:
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, stride=stride, kernel_size=3,
                                   padding=1, output_padding=stride - 1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Dropout2d(0.25),
            ]
        layers += [
            nn.ConvTranspose2d(32, 1, stride=2, kernel_size=3, padding=1,
                               output_padding=1),  # 32 -> 64
            Trim(),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        features = self.linear(z)
        feature_maps = torch.reshape(features, (-1, 64, 8, 8))
        return self.decoder(feature_maps)
# Smoke test: decode the latent codes `z` from the encoder check; the result
# should have the same shape as the input batch `X`.
decoder = Decoder()
X_prime = decoder(z)
(X.shape, X_prime.shape)
(torch.Size([32, 1, 64, 64]), torch.Size([32, 1, 64, 64]))
# First (untrained) reconstruction; matplotlib wants channels last (H, W, C).
plt.imshow(X_prime[0].detach().cpu().permute(1, 2, 0).numpy())
<matplotlib.image.AxesImage at 0x7fa272323670>
class Autoencoder(nn.Module):
    """Deterministic autoencoder: an ``Encoder`` followed by a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)
ae = Autoencoder().to(DEVICE)
# Element-wise MSE; the training loop calls it with reduction='none'.
loss_fn = F.mse_loss
optimizer = torch.optim.Adam(params=ae.parameters(), lr=5e-4)
# Layer-by-layer shapes and parameter counts for one batch of 1x64x64 images.
summary(ae, (batch_size, 1, 64, 64))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder [32, 1, 64, 64] --
├─Encoder: 1-1 [32, 2] --
│ └─Sequential: 2-1 [32, 2] --
│ │ └─Conv2d: 3-1 [32, 32, 32, 32] 320
│ │ └─BatchNorm2d: 3-2 [32, 32, 32, 32] 64
│ │ └─LeakyReLU: 3-3 [32, 32, 32, 32] --
│ │ └─Dropout2d: 3-4 [32, 32, 32, 32] --
│ │ └─Conv2d: 3-5 [32, 32, 32, 32] 9,248
│ │ └─BatchNorm2d: 3-6 [32, 32, 32, 32] 64
│ │ └─LeakyReLU: 3-7 [32, 32, 32, 32] --
│ │ └─Dropout2d: 3-8 [32, 32, 32, 32] --
│ │ └─Conv2d: 3-9 [32, 32, 32, 32] 9,248
│ │ └─BatchNorm2d: 3-10 [32, 32, 32, 32] 64
│ │ └─LeakyReLU: 3-11 [32, 32, 32, 32] --
│ │ └─Dropout2d: 3-12 [32, 32, 32, 32] --
│ │ └─Conv2d: 3-13 [32, 64, 16, 16] 18,496
│ │ └─BatchNorm2d: 3-14 [32, 64, 16, 16] 128
│ │ └─LeakyReLU: 3-15 [32, 64, 16, 16] --
│ │ └─Dropout2d: 3-16 [32, 64, 16, 16] --
│ │ └─Conv2d: 3-17 [32, 64, 16, 16] 36,928
│ │ └─BatchNorm2d: 3-18 [32, 64, 16, 16] 128
│ │ └─LeakyReLU: 3-19 [32, 64, 16, 16] --
│ │ └─Dropout2d: 3-20 [32, 64, 16, 16] --
│ │ └─Conv2d: 3-21 [32, 64, 8, 8] 36,928
│ │ └─BatchNorm2d: 3-22 [32, 64, 8, 8] 128
│ │ └─LeakyReLU: 3-23 [32, 64, 8, 8] --
│ │ └─Dropout2d: 3-24 [32, 64, 8, 8] --
│ │ └─Flatten: 3-25 [32, 4096] --
│ │ └─Linear: 3-26 [32, 2] 8,194
├─Decoder: 1-2 [32, 1, 64, 64] --
│ └─Sequential: 2-2 [32, 4096] --
│ │ └─Linear: 3-27 [32, 4096] 12,288
│ │ └─LeakyReLU: 3-28 [32, 4096] --
│ └─Sequential: 2-3 [32, 1, 64, 64] --
│ │ └─ConvTranspose2d: 3-29 [32, 64, 16, 16] 36,928
│ │ └─BatchNorm2d: 3-30 [32, 64, 16, 16] 128
│ │ └─LeakyReLU: 3-31 [32, 64, 16, 16] --
│ │ └─Dropout2d: 3-32 [32, 64, 16, 16] --
│ │ └─ConvTranspose2d: 3-33 [32, 64, 16, 16] 36,928
│ │ └─BatchNorm2d: 3-34 [32, 64, 16, 16] 128
│ │ └─LeakyReLU: 3-35 [32, 64, 16, 16] --
│ │ └─Dropout2d: 3-36 [32, 64, 16, 16] --
│ │ └─ConvTranspose2d: 3-37 [32, 32, 32, 32] 18,464
│ │ └─BatchNorm2d: 3-38 [32, 32, 32, 32] 64
│ │ └─LeakyReLU: 3-39 [32, 32, 32, 32] --
│ │ └─Dropout2d: 3-40 [32, 32, 32, 32] --
│ │ └─ConvTranspose2d: 3-41 [32, 32, 32, 32] 9,248
│ │ └─BatchNorm2d: 3-42 [32, 32, 32, 32] 64
│ │ └─LeakyReLU: 3-43 [32, 32, 32, 32] --
│ │ └─Dropout2d: 3-44 [32, 32, 32, 32] --
│ │ └─ConvTranspose2d: 3-45 [32, 32, 32, 32] 9,248
│ │ └─BatchNorm2d: 3-46 [32, 32, 32, 32] 64
│ │ └─LeakyReLU: 3-47 [32, 32, 32, 32] --
│ │ └─Dropout2d: 3-48 [32, 32, 32, 32] --
│ │ └─ConvTranspose2d: 3-49 [32, 1, 64, 64] 289
│ │ └─Trim: 3-50 [32, 1, 64, 64] --
│ │ └─Sigmoid: 3-51 [32, 1, 64, 64] --
==========================================================================================
Total params: 243,779
Trainable params: 243,779
Non-trainable params: 0
Total mult-adds (G): 3.00
==========================================================================================
Input size (MB): 0.52
Forward/backward pass size (MB): 138.41
Params size (MB): 0.98
Estimated Total Size (MB): 139.91
==========================================================================================
epochs = 50
outputs = []  # per epoch: (epoch, last input batch, its reconstruction)
losses = []   # one float per mini-batch

# Make sure dropout / batch-norm are in training mode, e.g. when this cell is
# re-run after plot_latent_space_with_labels() switched the model to eval().
ae.train()
for epoch in tqdm(range(epochs)):
    for image, _ in tqdm(dataloader):
        image = image.to(DEVICE)
        reconstructed = ae(image)

        # Sum the squared error over all pixels of each image, then average
        # over the images actually present.  flatten(start_dim=1) uses the
        # real batch size: view(batch_size, -1) would silently regroup pixels
        # across images on a shorter final batch.
        loss = loss_fn(reconstructed, image, reduction='none')
        loss = loss.flatten(start_dim=1).sum(dim=1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
    # Detach and copy to CPU so the stored tensors do not keep the autograd
    # graph (and device memory) alive for every epoch.
    outputs.append((epoch, image.detach().cpu(), reconstructed.detach().cpu()))
# Training curve, one point per mini-batch.
plt.plot(losses)
# First image of the last training batch, laid out as (H, W, C) on the CPU.
plt.imshow(image[0].cpu().reshape(-1, 64, 64).permute(1, 2, 0))
<matplotlib.image.AxesImage at 0x7fa265c14a60>
# Matching reconstruction of that image.
plt.imshow(reconstructed[0].detach().cpu().reshape(-1, 64, 64).permute(1, 2, 0))
<matplotlib.image.AxesImage at 0x7fa265bd6340>
def plot_latent_space_with_labels(num_classes, dataloader, model, device):
    """Scatter-plot the 2-D latent codes of every sample, coloured by label.

    Args:
        num_classes: labels 0 .. num_classes-1 are plotted.
        dataloader: iterable of ``(features, targets)`` batches.
        model: module whose ``encoder`` attribute maps features to codes; the
            first two columns of each code are plotted.
        device: device features and targets are moved to before encoding.

    The model runs in eval mode without gradients. Its previous train/eval mode
    is restored afterwards. Labels with no samples are skipped.
    """
    embeddings = {label: [] for label in range(num_classes)}

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for features, targets in dataloader:
                features = features.to(device)
                targets = targets.to(device)
                embedding = model.encoder(features)
                for label in range(num_classes):
                    mask = targets == label
                    if mask.any():
                        embeddings[label].append(embedding[mask].cpu().numpy())
    finally:
        model.train(was_training)

    colors = list(mcolors.TABLEAU_COLORS.values())
    for label in range(num_classes):
        if not embeddings[label]:
            continue  # np.concatenate([]) would raise ValueError
        points = np.concatenate(embeddings[label])
        plt.scatter(points[:, 0], points[:, 1],
                    color=colors[label % len(colors)],  # cycle past 10 classes
                    label=f'{label}',
                    alpha=0.5)

    plt.legend()
# Latent space of the trained autoencoder, one colour per class label.
plot_latent_space_with_labels(
    num_classes=3,
    dataloader=dataloader,
    model=ae,
    device=DEVICE,
)
plt.legend()
plt.show()
Beta-Variational Autoencoder (β-VAE)
class VAE(nn.Module):
    """Convolutional variational autoencoder with a 2-D latent space.

    Encoder: 1x64x64 image -> 4096 features -> ``z_mean`` and ``z_log_var``
    (each (N, 2)).
    Decoder: 2-D latent code -> 1x64x64 image with values in (0, 1).
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride); stride-2 stages halve H and W.
        encoder_stages = (
            (1, 32, 2),   # 64 -> 32
            (32, 32, 1),
            (32, 32, 1),
            (32, 64, 2),  # 32 -> 16
            (64, 64, 1),
            (64, 64, 2),  # 16 -> 8
        )
        encoder_layers = []
        for in_ch, out_ch, stride in encoder_stages:
            encoder_layers += [
                nn.Conv2d(in_ch, out_ch, stride=stride, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Dropout2d(0.25),
            ]
        encoder_layers.append(nn.Flatten())  # (N, 64, 8, 8) -> (N, 4096)
        self.encoder = nn.Sequential(*encoder_layers)

        # Heads for the mean and the log-variance of q(z | x).
        self.z_mean = nn.Linear(4096, 2)
        self.z_log_var = nn.Linear(4096, 2)

        # (in_channels, out_channels, stride); stride-2 stages double H and W.
        decoder_stages = (
            (64, 64, 2),  # 8 -> 16
            (64, 64, 1),
            (64, 32, 2),  # 16 -> 32
            (32, 32, 1),
            (32, 32, 1),
        )
        decoder_layers = [nn.Linear(2, 4096), Reshape(-1, 64, 8, 8)]
        for in_ch, out_ch, stride in decoder_stages:
            decoder_layers += [
                nn.ConvTranspose2d(in_ch, out_ch, stride=stride, kernel_size=3,
                                   padding=1, output_padding=stride - 1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Dropout2d(0.25),
            ]
        # The last transposed conv yields exactly 64x64
        # ((32 - 1) * 2 - 2 * 1 + 3 + 1 = 64), so no cropping layer is needed.
        decoder_layers += [
            nn.ConvTranspose2d(32, 1, stride=2, kernel_size=3, padding=1,
                               output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, z_mu, z_log_var):
        """Draw z = mu + exp(log_var / 2) * eps with eps ~ N(0, I).

        ``eps`` is created with ``torch.randn_like``, so it has the same shape,
        dtype and device as ``z_mu``. The previous
        ``.to(z_mu.get_device())`` failed for CPU tensors, where
        ``get_device()`` returns -1.
        """
        eps = torch.randn_like(z_mu)
        return z_mu + eps * torch.exp(z_log_var / 2.)

    def encoding_fn(self, x):
        """Encode ``x`` and return one sampled latent code per input."""
        features = self.encoder(x)
        z_mean, z_log_var = self.z_mean(features), self.z_log_var(features)
        return self.reparameterize(z_mean, z_log_var)

    def forward(self, x):
        """Return ``(encoded, z_mean, z_log_var, decoded)`` for a batch ``x``."""
        features = self.encoder(x)
        z_mean, z_log_var = self.z_mean(features), self.z_log_var(features)
        encoded = self.reparameterize(z_mean, z_log_var)
        decoded = self.decoder(encoded)
        return encoded, z_mean, z_log_var, decoded
vae = VAE()
vae.to(DEVICE)  # nn.Module.to moves the parameters in place
# Element-wise MSE; the training loop calls it with reduction='none'.
reconstruction_loss_fn = F.mse_loss
optimizer = torch.optim.Adam(params=vae.parameters(), lr=5e-5)
summary(vae, (32, 1, 64, 64))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
VAE [32, 2] --
├─Sequential: 1-1 [32, 4096] --
│ └─Conv2d: 2-1 [32, 32, 32, 32] 320
│ └─BatchNorm2d: 2-2 [32, 32, 32, 32] 64
│ └─LeakyReLU: 2-3 [32, 32, 32, 32] --
│ └─Dropout2d: 2-4 [32, 32, 32, 32] --
│ └─Conv2d: 2-5 [32, 32, 32, 32] 9,248
│ └─BatchNorm2d: 2-6 [32, 32, 32, 32] 64
│ └─LeakyReLU: 2-7 [32, 32, 32, 32] --
│ └─Dropout2d: 2-8 [32, 32, 32, 32] --
│ └─Conv2d: 2-9 [32, 32, 32, 32] 9,248
│ └─BatchNorm2d: 2-10 [32, 32, 32, 32] 64
│ └─LeakyReLU: 2-11 [32, 32, 32, 32] --
│ └─Dropout2d: 2-12 [32, 32, 32, 32] --
│ └─Conv2d: 2-13 [32, 64, 16, 16] 18,496
│ └─BatchNorm2d: 2-14 [32, 64, 16, 16] 128
│ └─LeakyReLU: 2-15 [32, 64, 16, 16] --
│ └─Dropout2d: 2-16 [32, 64, 16, 16] --
│ └─Conv2d: 2-17 [32, 64, 16, 16] 36,928
│ └─BatchNorm2d: 2-18 [32, 64, 16, 16] 128
│ └─LeakyReLU: 2-19 [32, 64, 16, 16] --
│ └─Dropout2d: 2-20 [32, 64, 16, 16] --
│ └─Conv2d: 2-21 [32, 64, 8, 8] 36,928
│ └─BatchNorm2d: 2-22 [32, 64, 8, 8] 128
│ └─LeakyReLU: 2-23 [32, 64, 8, 8] --
│ └─Dropout2d: 2-24 [32, 64, 8, 8] --
│ └─Flatten: 2-25 [32, 4096] --
├─Linear: 1-2 [32, 2] 8,194
├─Linear: 1-3 [32, 2] 8,194
├─Sequential: 1-4 [32, 1, 64, 64] --
│ └─Linear: 2-26 [32, 4096] 12,288
│ └─Reshape: 2-27 [32, 64, 8, 8] --
│ └─ConvTranspose2d: 2-28 [32, 64, 16, 16] 36,928
│ └─BatchNorm2d: 2-29 [32, 64, 16, 16] 128
│ └─LeakyReLU: 2-30 [32, 64, 16, 16] --
│ └─Dropout2d: 2-31 [32, 64, 16, 16] --
│ └─ConvTranspose2d: 2-32 [32, 64, 16, 16] 36,928
│ └─BatchNorm2d: 2-33 [32, 64, 16, 16] 128
│ └─LeakyReLU: 2-34 [32, 64, 16, 16] --
│ └─Dropout2d: 2-35 [32, 64, 16, 16] --
│ └─ConvTranspose2d: 2-36 [32, 32, 32, 32] 18,464
│ └─BatchNorm2d: 2-37 [32, 32, 32, 32] 64
│ └─LeakyReLU: 2-38 [32, 32, 32, 32] --
│ └─Dropout2d: 2-39 [32, 32, 32, 32] --
│ └─ConvTranspose2d: 2-40 [32, 32, 32, 32] 9,248
│ └─BatchNorm2d: 2-41 [32, 32, 32, 32] 64
│ └─LeakyReLU: 2-42 [32, 32, 32, 32] --
│ └─Dropout2d: 2-43 [32, 32, 32, 32] --
│ └─ConvTranspose2d: 2-44 [32, 32, 32, 32] 9,248
│ └─BatchNorm2d: 2-45 [32, 32, 32, 32] 64
│ └─LeakyReLU: 2-46 [32, 32, 32, 32] --
│ └─Dropout2d: 2-47 [32, 32, 32, 32] --
│ └─ConvTranspose2d: 2-48 [32, 1, 64, 64] 289
│ └─Sigmoid: 2-49 [32, 1, 64, 64] --
==========================================================================================
Total params: 251,973
Trainable params: 251,973
Non-trainable params: 0
Total mult-adds (G): 3.00
==========================================================================================
Input size (MB): 0.52
Forward/backward pass size (MB): 138.41
Params size (MB): 1.01
Estimated Total Size (MB): 139.95
==========================================================================================
epochs = 50
beta = 6  # weight of the KL term in the loss
vae_outputs = []            # per epoch: (epoch, last input batch, reconstruction)
combined_losses = []        # per mini-batch, as floats
reconstruction_losses = []
kl_losses = []

# Make sure dropout / batch-norm are in training mode (e.g. on a re-run after
# the model was switched to eval() for plotting or sampling).
vae.train()
for epoch in tqdm(range(epochs)):
    for image, _ in tqdm(dataloader):
        image = image.to(DEVICE)
        # The third output of VAE.forward is the log-variance of q(z | x).
        encoded, z_mean, z_log_var, decoded = vae(image)

        # KL divergence to N(0, I), per sample, then averaged over the batch.
        kl_div = -0.5 * torch.sum(
            1 + z_log_var - z_mean**2 - torch.exp(z_log_var), dim=1)
        kl_div = kl_div.mean()

        # Per-image summed squared error, averaged over the batch.
        # flatten(start_dim=1) adapts to a shorter final batch without
        # overwriting the global `batch_size`.
        reconstruction_loss = reconstruction_loss_fn(decoded, image, reduction='none')
        reconstruction_loss = reconstruction_loss.flatten(start_dim=1).sum(dim=1).mean()

        loss = reconstruction_loss + beta * kl_div

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        combined_losses.append(loss.item())
        reconstruction_losses.append(reconstruction_loss.item())
        kl_losses.append(kl_div.item())
    # Detach and copy to CPU so stored tensors do not keep the autograd graph
    # (and device memory) alive for every epoch.
    vae_outputs.append((epoch, image.detach().cpu(), decoded.detach().cpu()))
# Loss curves, one point per mini-batch: total loss ...
plt.plot(combined_losses)
# ... reconstruction term ...
plt.plot(reconstruction_losses)
# ... and (unweighted) KL term.
plt.plot(kl_losses)
def plot_latent_space_with_labels(num_classes, dataloader, encoding_fn, device):
    """Scatter-plot 2-D latent codes produced by ``encoding_fn``, coloured by label.

    Args:
        num_classes: labels 0 .. num_classes-1 are plotted.
        dataloader: iterable of ``(features, targets)`` batches.
        encoding_fn: callable mapping a feature batch to latent codes; the
            first two columns of each code are plotted.
        device: device features and targets are moved to before encoding.

    Gradients are disabled, but no module's train/eval mode is changed here.
    Call ``model.eval()`` beforehand so dropout is off and batch-norm running
    statistics are not updated by these forward passes. Labels with no
    samples are skipped.
    """
    embeddings = {label: [] for label in range(num_classes)}

    with torch.no_grad():
        for features, targets in dataloader:
            features = features.to(device)
            targets = targets.to(device)
            embedding = encoding_fn(features)
            for label in range(num_classes):
                mask = targets == label
                if mask.any():
                    embeddings[label].append(embedding[mask].cpu().numpy())

    colors = list(mcolors.TABLEAU_COLORS.values())
    for label in range(num_classes):
        if not embeddings[label]:
            continue  # np.concatenate([]) would raise ValueError
        points = np.concatenate(embeddings[label])
        plt.scatter(points[:, 0], points[:, 1],
                    color=colors[label % len(colors)],  # cycle past 10 classes
                    label=f'{label}',
                    alpha=0.5)

    plt.legend()
# Eval mode while embedding the dataset: disables dropout and stops the
# batch-norm layers from updating their running statistics.
vae.eval()
plot_latent_space_with_labels(
    num_classes=3,
    dataloader=dataloader,
    encoding_fn=vae.encoding_fn,
    device=DEVICE,
)
plt.legend()
plt.show()
# 16 latent points along the second latent axis, from -1.2 to 1.8 in steps of
# 0.2, with the first coordinate fixed at 0.  linspace fixes the number of
# points exactly; arange with a float step is sensitive to rounding at the
# end point.
y_axis = torch.linspace(-1.2, 1.8, steps=16)
x_axis = torch.zeros_like(y_axis)
samples = torch.stack((x_axis, y_axis), dim=1).to(DEVICE)
samples
tensor([[ 0.0000e+00, -1.2000e+00],
[ 0.0000e+00, -1.0000e+00],
[ 0.0000e+00, -8.0000e-01],
[ 0.0000e+00, -6.0000e-01],
[ 0.0000e+00, -4.0000e-01],
[ 0.0000e+00, -2.0000e-01],
[ 0.0000e+00, -4.7684e-08],
[ 0.0000e+00, 2.0000e-01],
[ 0.0000e+00, 4.0000e-01],
[ 0.0000e+00, 6.0000e-01],
[ 0.0000e+00, 8.0000e-01],
[ 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, 1.2000e+00],
[ 0.0000e+00, 1.4000e+00],
[ 0.0000e+00, 1.6000e+00],
[ 0.0000e+00, 1.8000e+00]], device='cuda:0')
def show_images(images, nmax=64):
    """Plot up to ``nmax`` images of a (N, C, H, W) tensor as an 8-wide grid.

    The tensor is detached and moved to the CPU first, so batches that live on
    the GPU can be plotted as well.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(images.detach()[:nmax].cpu(), nrow=8)
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0))
# Decode all latent points in `samples` in a single batch. Eval mode turns off
# dropout and makes batch-norm use its running statistics, so each point is
# decoded independently of the others.
vae.eval()
with torch.no_grad():
    decoded_samples = vae.decoder(samples).cpu()  # (len(samples), 1, 64, 64)
img_samples = list(decoded_samples)  # one (1, 64, 64) image per latent point
grid = make_grid(img_samples)
def show(imgs):
    """Display one image tensor, or a list of them side by side in one row."""
    if not isinstance(imgs, list):
        imgs = [imgs]
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(12, 12))
    for i, img in enumerate(imgs):
        pil_img = transforms.functional.to_pil_image(img.detach())
        # With squeeze=False and ncols=len(imgs), `axs` has shape
        # (1, len(imgs)): the row index is always 0 and the column is i.
        # Indexing axs[i, 0] raised IndexError for more than one image.
        axs[0, i].imshow(np.asarray(pil_img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


show(grid)
# Decode one hand-picked latent point (0, 0.8) on the configured device
# instead of hard-coding .cuda(), so this also runs on CPU-only machines.
vae.eval()
with torch.no_grad():
    sample_out = vae.decoder(torch.tensor([0., 0.8], device=DEVICE))
plt.imshow(sample_out.reshape(-1, 64, 64).cpu().permute(1, 2, 0))