"Open

In [10]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

In [11]:
!pip install gdown



In [12]:
def get_file_id_by_model(folder_name):
 file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
 'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C'}
 return file_id.get(folder_name, "Model not found.")

In [13]:
folder_name = 'resnet18_100-epochs_cifar10'
file_id = get_file_id_by_model(folder_name)
print(folder_name, file_id)

resnet18_100-epochs_cifar10 1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C


In [14]:
# download and extract model files
os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))
os.system('unzip {}'.format(folder_name))
!ls

checkpoint_0100.pth.tar
config.yml
events.out.tfevents.1610901418.4cb2c837708d.2683796.0
resnet18_100-epochs_cifar10.zip
resnet18_100-epochs-cifar10.zip
run.log
sample_data


In [15]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [17]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
 train_dataset = datasets.STL10('./data', split='train', download=download,
 transform=transforms.ToTensor())

 train_loader = DataLoader(train_dataset, batch_size=batch_size,
 num_workers=0, drop_last=False, shuffle=shuffle)
 
 test_dataset = datasets.STL10('./data', split='test', download=download,
 transform=transforms.ToTensor())

 test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
 num_workers=10, drop_last=False, shuffle=shuffle)
 return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
 train_dataset = datasets.CIFAR10('./data', train=True, download=download,
 transform=transforms.ToTensor())

 train_loader = DataLoader(train_dataset, batch_size=batch_size,
 num_workers=0, drop_last=False, shuffle=shuffle)
 
 test_dataset = datasets.CIFAR10('./data', train=False, download=download,
 transform=transforms.ToTensor())

 test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
 num_workers=10, drop_last=False, shuffle=shuffle)
 return train_loader, test_loader

In [18]:
with open(os.path.join('./config.yml')) as file:
 config = yaml.load(file)

In [19]:
if config.arch == 'resnet18':
 model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == 'resnet50':
 model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)

In [20]:
checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):

 if k.startswith('backbone.'):
 if k.startswith('backbone') and not k.startswith('backbone.fc'):
 # remove prefix
 state_dict[k[len("backbone."):]] = state_dict[k]
 del state_dict[k]

In [21]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [22]:
if config.dataset_name == 'cifar10':
 train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config.dataset_name == 'stl10':
 train_loader, test_loader = get_stl10_data_loaders(download=True)
print("Dataset:", config.dataset_name)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Dataset: cifar10


In [23]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
 if name not in ['fc.weight', 'fc.bias']:
 param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2 # fc.weight, fc.bias

In [24]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [25]:
def accuracy(output, target, topk=(1,)):
 """Computes the accuracy over the k top predictions for the specified values of k"""
 with torch.no_grad():
 maxk = max(topk)
 batch_size = target.size(0)

 _, pred = output.topk(maxk, 1, True, True)
 pred = pred.t()
 correct = pred.eq(target.view(1, -1).expand_as(pred))

 res = []
 for k in topk:
 correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
 res.append(correct_k.mul_(100.0 / batch_size))
 return res

In [26]:
epochs = 100
for epoch in range(epochs):
 top1_train_accuracy = 0
 for counter, (x_batch, y_batch) in enumerate(train_loader):
 x_batch = x_batch.to(device)
 y_batch = y_batch.to(device)

 logits = model(x_batch)
 loss = criterion(logits, y_batch)
 top1 = accuracy(logits, y_batch, topk=(1,))
 top1_train_accuracy += top1[0]

 optimizer.zero_grad()
 loss.backward()
 optimizer.step()

 top1_train_accuracy /= (counter + 1)
 top1_accuracy = 0
 top5_accuracy = 0
 for counter, (x_batch, y_batch) in enumerate(test_loader):
 x_batch = x_batch.to(device)
 y_batch = y_batch.to(device)

 logits = model(x_batch)
 
 top1, top5 = accuracy(logits, y_batch, topk=(1,5))
 top1_accuracy += top1[0]
 top5_accuracy += top5[0]
 
 top1_accuracy /= (counter + 1)
 top5_accuracy /= (counter + 1)
 print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

Epoch 0	Top1 Train accuracy 49.823020935058594	Top1 Test accuracy: 57.63786697387695	Top5 test acc: 94.96036529541016
Epoch 1	Top1 Train accuracy 59.0130729675293	Top1 Test accuracy: 59.57088851928711	Top5 test acc: 95.76114654541016
Epoch 2	Top1 Train accuracy 60.604671478271484	Top1 Test accuracy: 60.32686233520508	Top5 test acc: 96.07250213623047
Epoch 3	Top1 Train accuracy 61.547752380371094	Top1 Test accuracy: 61.19715118408203	Top5 test acc: 96.14946746826172
Epoch 4	Top1 Train accuracy 62.19586944580078	Top1 Test accuracy: 61.48035430908203	Top5 test acc: 96.37407684326172
Epoch 5	Top1 Train accuracy 62.677772521972656	Top1 Test accuracy: 61.784236907958984	Top5 test acc: 96.40337371826172
Epoch 6	Top1 Train accuracy 63.06640625	Top1 Test accuracy: 62.2346076965332	Top5 test acc: 96.50102996826172
Epoch 7	Top1 Train accuracy 63.40122604370117	Top1 Test accuracy: 62.52527618408203	Top5 test acc: 96.46196746826172
Epoch 8	Top1 Train accuracy 63.698577880859375	Top1 Test accuracy: 