import torch
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch import nn
from tqdm.auto import tqdm
from torchmetrics import Accuracy
from torchvision import transforms
from torchinfo import summary
Paper Replication: Vision Transformer
Python
Deep Learning
Computer Vision
The first in a series of paper replications that will collectively lay the foundation for a full from-scratch replication of the Stable Diffusion model.
# Keep a small 1000-image random subset of CIFAR-10 train for quick experiments.
train_split_len = 1000
trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=ToTensor())
# random_split returns [kept_subset, remainder]; keep only the small subset.
part_train = torch.utils.data.random_split(trainset, [train_split_len, len(trainset) - train_split_len])[0]
trainloader = DataLoader(part_train, batch_size=25, shuffle=True)
# Output: Files already downloaded and verified
# Keep a small 200-image random subset of CIFAR-10 test for quick evaluation.
test_split_len = 200
testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=ToTensor())
# random_split returns [kept_subset, remainder]; keep only the small subset.
part_test = torch.utils.data.random_split(testset, [test_split_len, len(testset) - test_split_len])[0]
testloader = DataLoader(part_test, batch_size=25, shuffle=True)
# Output: Files already downloaded and verified
# Map class name -> integer label from the dataset metadata.
classes = testset.class_to_idx
classes
# Output:
# {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
#  'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
# Pull one batch and sanity-check shapes before building the model.
dataiter = iter(trainloader)
images, labels = next(dataiter)
images.shape, labels.shape
# Output: (torch.Size([25, 3, 32, 32]), torch.Size([25]))
# Look up the class name for label value 1 and display the matching image.
print(list(classes.keys())[list(classes.values()).index(1)])   # automobile
# Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib.
plt.imshow(images[1].permute(1, 2, 0).numpy())
class LinearProjectionBlock(nn.Module):
    """Patchify an image and linearly project the patches to embeddings (ViT Eq. 1).

    A Conv2d with kernel_size == stride == patch_size performs the patch split
    and linear projection in one step. A learnable [class] token is prepended
    and learnable position embeddings are added.

    Args:
        in_channels: input image channels.
        patch_size: side length of each square patch.
        embedding_dim: size of each patch embedding.
        batch_size: kept for backward compatibility; no longer used — the
            class token is expanded to the runtime batch size instead.
        num_patches: (img_size // patch_size) ** 2, used to size the
            position embeddings.
    """

    def __init__(self, in_channels=3, patch_size=2, embedding_dim=768, batch_size=25, num_patches=256):
        super().__init__()
        self.patch = nn.Conv2d(in_channels=in_channels, out_channels=embedding_dim,
                               kernel_size=patch_size, stride=patch_size)
        # Flatten the (H/P, W/P) patch grid into one sequence dimension.
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)
        # Fix: these parameters were previously created with a fixed batch
        # dimension (batch_size, ...), which broke any batch size other than
        # the one given at construction. They are now batch-agnostic and are
        # expanded / broadcast at forward time.
        self.class_token = nn.Parameter(torch.rand(1, 1, embedding_dim), requires_grad=True)
        self.positional = nn.Parameter(torch.rand(1, num_patches + 1, embedding_dim), requires_grad=True)

    def forward(self, x):
        # (B, C, H, W) -> (B, D, H/P, W/P)
        x = self.patch(x)
        # -> (B, D, N) -> (B, N, D)
        x = self.flatten(x).permute(0, 2, 1)
        # Prepend the class token (expanded to this batch) and add positions.
        cls = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls, x), dim=1)
        return self.positional + x
class MSABlock(nn.Module):
    """Pre-norm multi-head self-attention sub-block (ViT Eq. 2, without the residual).

    LayerNorm is applied before attention ("pre-norm"), matching the ViT
    paper. The residual connection is intentionally left to the caller
    (see TrasformerEncoderBlock).

    Args:
        embedding_dim: token embedding size.
        num_heads: number of attention heads (must divide embedding_dim).
    """

    def __init__(self, embedding_dim=768, num_heads=12):
        super().__init__()
        self.normalize = nn.LayerNorm(normalized_shape=embedding_dim)
        self.msa = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)

    def forward(self, x):
        x = self.normalize(x)
        # Self-attention: query, key and value are all the same sequence;
        # need_weights=False skips computing/returning the attention map.
        x, _ = self.msa(query=x, key=x, value=x, need_weights=False)
        return x
# Sanity-check the projection and attention blocks on one CIFAR-10 batch.
lin_projs = LinearProjectionBlock(in_channels=3, patch_size=2, embedding_dim=768, batch_size=25, num_patches=256)
projs = lin_projs(images)
projs.shape
# Output: torch.Size([25, 257, 768])  (256 patches + 1 class token)
msa_block = MSABlock(embedding_dim=768, num_heads=12)
out = msa_block(projs)
out.shape
# Output: torch.Size([25, 257, 768])
class MLPBlock(nn.Module):
    """Pre-norm feed-forward sub-block (ViT Eq. 3, without the residual).

    LayerNorm, then Linear -> GELU -> Dropout -> Linear -> Dropout, expanding
    to mlp_size and projecting back to embedding_dim.

    Args:
        embedding_dim: token embedding size (input and output width).
        mlp_size: hidden width of the expansion layer.
    """

    def __init__(self, embedding_dim=768, mlp_size=3072):
        super().__init__()
        self.normalize = nn.LayerNorm(embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(in_features=mlp_size, out_features=embedding_dim),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        return self.mlp(self.normalize(x))
class TrasformerEncoderBlock(nn.Module):
    """One ViT transformer encoder layer: pre-norm MSA and MLP sub-blocks.

    NOTE(review): the class name keeps the original (misspelled) identifier
    so existing references still resolve.

    Fix: the original composed the two sub-blocks WITHOUT the residual
    (skip) connections required by the ViT paper (Eqs. 2-3), which makes a
    deep stack effectively untrainable; the residuals are added here.
    """

    def __init__(self, embedding_dim=768, num_heads=12, mlp_size=3072):
        super().__init__()
        self.msa_block = MSABlock(embedding_dim=embedding_dim, num_heads=num_heads)
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim, mlp_size=mlp_size)

    def forward(self, x):
        x = x + self.msa_block(x)   # Eq. 2: z' = MSA(LN(z)) + z
        x = x + self.mlp_block(x)   # Eq. 3: z  = MLP(LN(z')) + z'
        return x
# Sanity-check the MLP sub-block on the attention output.
mlp = MLPBlock(embedding_dim=768)
fout = mlp(out)
class ViT(nn.Module):
    """Vision Transformer image classifier (Dosovitskiy et al., 2020).

    Pipeline: patch + position embedding -> stack of `layers` pre-norm
    transformer encoder layers (PyTorch built-ins) -> LayerNorm + Linear
    head applied to the [class] token.

    Args:
        img_size: input image side length (informational; num_patches must
            equal (img_size // patch_size) ** 2).
        in_channels: image channels.
        patch_size: square patch side length.
        embedding_dim: token embedding size (d_model).
        batch_size: forwarded to LinearProjectionBlock for compatibility.
        num_patches: number of patches per image.
        mlp_size: hidden size of each encoder layer's feed-forward MLP.
        layers: number of stacked encoder layers.
        num_heads: attention heads per layer.
        num_classes: output classes.
    """

    def __init__(self, img_size=32, in_channels=3, patch_size=2, embedding_dim=768,
                 batch_size=25, num_patches=256, mlp_size=3072, layers=12, num_heads=12, num_classes=10):
        super().__init__()
        self.projections = LinearProjectionBlock(in_channels=in_channels, patch_size=patch_size,
                                                 embedding_dim=embedding_dim, batch_size=batch_size,
                                                 num_patches=num_patches)
        # norm_first=True gives the pre-norm layout used by the ViT paper.
        self.transformer_encoder = nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                              nhead=num_heads,
                                                              dim_feedforward=mlp_size,
                                                              dropout=0.1,
                                                              activation="gelu",
                                                              batch_first=True,
                                                              norm_first=True)
        self.encoder_layer = nn.TransformerEncoder(self.transformer_encoder, num_layers=layers)
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes),
        )

    def forward(self, x, layers=12):
        # NOTE(review): `layers` is unused here (depth is fixed in __init__);
        # kept only for backward compatibility with existing callers.
        x = self.projections(x)
        x = self.encoder_layer(x)
        # Classify from the [class] token (sequence position 0) only.
        return self.classifier(x[:, 0])
# patch_size=4 on 32x32 images -> (32 // 4) ** 2 = 64 patches.
vit = ViT(in_channels=3, patch_size=4, embedding_dim=768, batch_size=25, num_patches=64, mlp_size=2048).to('cuda')
summary(vit, input_size=(25, 3, 32, 32), col_names=["input_size"])
======================================================================
Layer (type:depth-idx) Input Shape
======================================================================
ViT [25, 3, 32, 32]
├─LinearProjectionBlock: 1-1 [25, 3, 32, 32]
│ └─Conv2d: 2-1 [25, 3, 32, 32]
│ └─Flatten: 2-2 [25, 768, 8, 8]
├─TransformerEncoder: 1-2 [25, 65, 768]
│ └─ModuleList: 2-3 --
│ │ └─TransformerEncoderLayer: 3-1 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-2 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-3 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-4 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-5 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-6 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-7 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-8 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-9 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-10 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-11 [25, 65, 768]
│ │ └─TransformerEncoderLayer: 3-12 [25, 65, 768]
├─Sequential: 1-3 [25, 768]
│ └─LayerNorm: 2-4 [25, 768]
│ └─Linear: 2-5 [25, 768]
======================================================================
Total params: 72,995,850
Trainable params: 72,995,850
Non-trainable params: 0
Total mult-adds (G): 1.01
======================================================================
Input size (MB): 0.31
Forward/backward pass size (MB): 688.90
Params size (MB): 151.47
Estimated Total Size (MB): 840.67
======================================================================
# Forward one untrained batch to verify output shape and inspect a prediction.
test_output = vit(images.to('cuda'))
test_output.shape
# Output: torch.Size([25, 10])
torch.argmax(test_output[1]), labels[1]
# Output: (tensor(0, device='cuda:0'), tensor(9))  — untrained, so wrong is expected
loss_fn = nn.CrossEntropyLoss()
# NOTE(review): lr=1e-3 with weight_decay=0.1 appears too aggressive — the
# training log below plateaus at chance accuracy (~10%). Consider a smaller
# lr with warmup, as in the ViT paper's training setup.
optimizer = torch.optim.Adam(vit.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.1)
accuracy = Accuracy(task="multiclass", num_classes=10)  # defined but unused below
def train_step(model, dataloader, loss_fn, optimizer):
    """Run one training epoch over `dataloader`.

    Args:
        model: the model to train (its parameters' device is used for data).
        dataloader: iterable of (X, y) batches.
        loss_fn: criterion taking (logits, targets).
        optimizer: optimizer over model.parameters().

    Returns:
        (mean_loss, mean_accuracy) averaged over batches.

    Fix: batches are moved to the model's own device instead of a
    hard-coded 'cuda', so the function also works on CPU-only machines.
    """
    model.train()
    device = next(model.parameters()).device
    train_loss, train_acc = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # softmax is monotonic, so argmax over raw logits would be
        # equivalent; kept to mirror the original notebook.
        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
        train_acc += (y_pred_class == y).sum().item() / len(y_pred)
    # Average the per-batch sums over the number of batches.
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc
def test_step(model, dataloader, loss_fn):
eval()
model.
= 0, 0
test_loss, test_acc
with torch.inference_mode():
for batch, (X, y) in enumerate(dataloader):
= X.to('cuda'), y.to('cuda')
X, y
= model(X)
test_pred_logits
= loss_fn(test_pred_logits, y)
loss += loss.item()
test_loss
= test_pred_logits.argmax(dim=1)
test_pred_labels += ((test_pred_labels == y).sum().item()/len(test_pred_labels))
test_acc
= test_loss / len(dataloader)
test_loss = test_acc / len(dataloader)
test_acc return test_loss, test_acc
def train(model, train_dataloader, test_dataloader, optimizer, loss_fn, epochs=50):
    """Train and evaluate `model` for `epochs` epochs.

    Runs one train_step and one test_step per epoch, prints per-epoch
    metrics, and collects them.

    Returns:
        dict with keys "train_loss", "train_acc", "test_loss", "test_acc",
        each a list with one float per epoch.
    """
    results = {"train_loss": [],
               "train_acc": [],
               "test_loss": [],
               "test_acc": []}
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer)
        test_loss, test_acc = test_step(model=model,
                                        dataloader=test_dataloader,
                                        loss_fn=loss_fn)
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)
    return results
# Train on the CIFAR-10 subsets defined above.
results = train(model=vit, train_dataloader=trainloader, test_dataloader=testloader, optimizer=optimizer, loss_fn=loss_fn, epochs=50)
Epoch: 1 | train_loss: 2.3983 | train_acc: 0.1070 | test_loss: 2.3546 | test_acc: 0.0750
Epoch: 2 | train_loss: 2.4793 | train_acc: 0.1120 | test_loss: 2.4665 | test_acc: 0.1100
Epoch: 3 | train_loss: 2.4430 | train_acc: 0.1060 | test_loss: 2.3757 | test_acc: 0.0900
Epoch: 4 | train_loss: 2.3699 | train_acc: 0.1010 | test_loss: 2.3697 | test_acc: 0.1050
Epoch: 5 | train_loss: 2.3734 | train_acc: 0.1100 | test_loss: 2.3556 | test_acc: 0.1450
Epoch: 6 | train_loss: 2.3820 | train_acc: 0.0850 | test_loss: 2.3614 | test_acc: 0.0800
Epoch: 7 | train_loss: 2.3696 | train_acc: 0.0810 | test_loss: 2.3599 | test_acc: 0.1000
Epoch: 8 | train_loss: 2.3491 | train_acc: 0.0960 | test_loss: 2.3836 | test_acc: 0.1200
Epoch: 9 | train_loss: 2.3625 | train_acc: 0.1030 | test_loss: 2.3677 | test_acc: 0.0500
Epoch: 10 | train_loss: 2.3425 | train_acc: 0.1120 | test_loss: 2.3327 | test_acc: 0.1100
Epoch: 11 | train_loss: 2.3501 | train_acc: 0.1020 | test_loss: 2.3617 | test_acc: 0.1050
Epoch: 12 | train_loss: 2.3455 | train_acc: 0.0840 | test_loss: 2.3424 | test_acc: 0.0800
Epoch: 13 | train_loss: 2.3293 | train_acc: 0.1050 | test_loss: 2.3220 | test_acc: 0.1400
Epoch: 14 | train_loss: 2.3237 | train_acc: 0.0930 | test_loss: 2.3216 | test_acc: 0.1150
Epoch: 15 | train_loss: 2.3416 | train_acc: 0.0970 | test_loss: 2.3089 | test_acc: 0.1350
Epoch: 16 | train_loss: 2.3387 | train_acc: 0.0750 | test_loss: 2.3197 | test_acc: 0.0900
Epoch: 17 | train_loss: 2.3219 | train_acc: 0.0990 | test_loss: 2.3107 | test_acc: 0.0900
Epoch: 18 | train_loss: 2.3260 | train_acc: 0.1070 | test_loss: 2.3100 | test_acc: 0.1100
Epoch: 19 | train_loss: 2.3187 | train_acc: 0.1010 | test_loss: 2.3043 | test_acc: 0.1000
Epoch: 20 | train_loss: 2.3256 | train_acc: 0.1060 | test_loss: 2.3108 | test_acc: 0.1300
Epoch: 21 | train_loss: 2.3152 | train_acc: 0.1050 | test_loss: 2.3279 | test_acc: 0.1150
Epoch: 22 | train_loss: 2.3160 | train_acc: 0.1070 | test_loss: 2.3138 | test_acc: 0.1100
Epoch: 23 | train_loss: 2.3145 | train_acc: 0.1120 | test_loss: 2.3058 | test_acc: 0.0800
Epoch: 24 | train_loss: 2.3119 | train_acc: 0.0900 | test_loss: 2.3030 | test_acc: 0.1450
Epoch: 25 | train_loss: 2.3168 | train_acc: 0.0940 | test_loss: 2.3119 | test_acc: 0.0950
Epoch: 26 | train_loss: 2.3111 | train_acc: 0.0950 | test_loss: 2.3132 | test_acc: 0.0650
Epoch: 27 | train_loss: 2.3035 | train_acc: 0.1120 | test_loss: 2.3098 | test_acc: 0.0700
Epoch: 28 | train_loss: 2.3060 | train_acc: 0.1150 | test_loss: 2.3021 | test_acc: 0.1250
Epoch: 29 | train_loss: 2.3059 | train_acc: 0.1050 | test_loss: 2.3005 | test_acc: 0.0950
Epoch: 30 | train_loss: 2.3072 | train_acc: 0.1010 | test_loss: 2.3023 | test_acc: 0.0800
Epoch: 31 | train_loss: 2.3089 | train_acc: 0.0980 | test_loss: 2.3047 | test_acc: 0.0950
Epoch: 32 | train_loss: 2.3053 | train_acc: 0.1050 | test_loss: 2.3046 | test_acc: 0.0850
Epoch: 33 | train_loss: 2.3035 | train_acc: 0.1070 | test_loss: 2.3072 | test_acc: 0.0600
Epoch: 34 | train_loss: 2.3030 | train_acc: 0.1120 | test_loss: 2.3081 | test_acc: 0.1000
Epoch: 35 | train_loss: 2.3018 | train_acc: 0.1110 | test_loss: 2.2997 | test_acc: 0.1250
Epoch: 36 | train_loss: 2.3030 | train_acc: 0.1090 | test_loss: 2.3014 | test_acc: 0.1000
Epoch: 37 | train_loss: 2.3032 | train_acc: 0.0990 | test_loss: 2.2988 | test_acc: 0.1000
Epoch: 38 | train_loss: 2.3023 | train_acc: 0.1120 | test_loss: 2.3021 | test_acc: 0.1000
Epoch: 39 | train_loss: 2.3024 | train_acc: 0.1140 | test_loss: 2.3017 | test_acc: 0.1000
Epoch: 40 | train_loss: 2.3012 | train_acc: 0.1140 | test_loss: 2.3006 | test_acc: 0.1000
Epoch: 41 | train_loss: 2.3017 | train_acc: 0.1170 | test_loss: 2.3014 | test_acc: 0.1000
Epoch: 42 | train_loss: 2.3007 | train_acc: 0.1140 | test_loss: 2.3011 | test_acc: 0.1000
Epoch: 43 | train_loss: 2.3022 | train_acc: 0.1110 | test_loss: 2.3002 | test_acc: 0.1000
Epoch: 44 | train_loss: 2.3013 | train_acc: 0.1140 | test_loss: 2.3010 | test_acc: 0.1000
Epoch: 45 | train_loss: 2.3013 | train_acc: 0.1140 | test_loss: 2.3011 | test_acc: 0.1000
Epoch: 46 | train_loss: 2.3014 | train_acc: 0.1140 | test_loss: 2.3008 | test_acc: 0.1000
Epoch: 47 | train_loss: 2.3011 | train_acc: 0.1140 | test_loss: 2.3010 | test_acc: 0.1000
Epoch: 48 | train_loss: 2.3013 | train_acc: 0.1140 | test_loss: 2.3011 | test_acc: 0.1000
Epoch: 49 | train_loss: 2.3013 | train_acc: 0.1140 | test_loss: 2.3012 | test_acc: 0.1000
Epoch: 50 | train_loss: 2.3009 | train_acc: 0.1140 | test_loss: 2.3014 | test_acc: 0.1000
# Loss curves: train vs. test.
plt.plot(results['train_loss'])
plt.plot(results['test_loss'])
# Training accuracy curve.
plt.plot(results['train_acc'])