Remove distillation folder
parent 33172fba61
commit c38640be55
@@ -1,45 +0,0 @@
student:
  model_name: resnet

teacher:
  model_name: dinov2_vitg14
  out_dim: 1536

data_transform:
  n_global_crops: 2
  n_local_crops: 8
  global_crops_scale: [0.32, 1.0]
  local_crops_scale: [0.05, 0.32]
  global_crops_size: [224, 224]
  local_crops_size: [224, 224]

data_loader:
  batch_size: 32
  num_workers: 8
  shuffle: true
  collate_fn: collate_data_and_cast

optimizer:
  type: AdamW
  lr: 0.0005

dino_loss:
  weight: 1
  student_temp: 0.1
  teacher_temp: 0.07

train:
  max_epochs: 100
  save_checkpoint_freq: 10
  resume_checkpoint: null

model_wrapper:
  model_type: resnet
  n_patches: 256
  target_feature:
    - res5
  feature_matcher_config:
    out_channels: 1536
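For context, a minimal sketch of how this config was consumed; the same open/safe_load pattern appears in the removed notebook further down, while the relative path is an assumption about where the file lived inside the distillation folder:

import yaml

# Load the distillation config and read the values used by the data pipeline and optimizer.
with open("config/config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

batch_size = cfg["data_loader"]["batch_size"]            # 32
lr = cfg["optimizer"]["lr"]                              # 0.0005
n_local_crops = cfg["data_transform"]["n_local_crops"]   # 8
print(batch_size, lr, n_local_crops)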
@@ -1,31 +0,0 @@
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from typing import Optional
from torchvision import transforms


class GTA5Dataset(Dataset):
    def __init__(self, img_dir: str = "/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/GTA5/GTA5/images", transform: Optional[transforms.Compose] = None):
        self.img_dir = img_dir
        self.transform = transform

        # Get all image files
        self.images = []
        for img_name in os.listdir(self.img_dir):
            if img_name.endswith(('.jpg', '.png')):
                self.images.append(os.path.join(self.img_dir, img_name))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image
@@ -1,13 +0,0 @@
import torch


def collate_data_and_cast(samples_list):
    n_global_crops = len(samples_list[0]["global_crops"])
    n_local_crops = len(samples_list[0]["local_crops"])

    collated_global_crops = torch.stack([s["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
    collated_local_crops = torch.stack([s["local_crops"][i] for i in range(n_local_crops) for s in samples_list])

    return {
        "collated_global_crops": collated_global_crops,
        "collated_local_crops": collated_local_crops,
    }
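A quick sanity check of the collate function above, using dummy tensors shaped like DINO-style augmentation output; crop counts mirror n_global_crops/n_local_crops from the config, the sample count is arbitrary:

import torch

samples = [
    {
        "global_crops": [torch.randn(3, 224, 224) for _ in range(2)],
        "local_crops": [torch.randn(3, 224, 224) for _ in range(8)],
    }
    for _ in range(4)
]

batch = collate_data_and_cast(samples)
print(batch["collated_global_crops"].shape)  # torch.Size([8, 3, 224, 224])  -> 2 crops x 4 samples
print(batch["collated_local_crops"].shape)   # torch.Size([32, 3, 224, 224]) -> 8 crops x 4 samples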
@@ -1,40 +0,0 @@
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from typing import Optional
from torchvision import transforms
from datasets import load_dataset
import datasets
import numpy as np
from torch.utils.data import Dataset, Subset
import random


class ImageNetDataset(Dataset):
    def __init__(self, type='train', transform: Optional[transforms.Compose] = None, num_samples: Optional[int] = None):
        self.dataset = load_dataset('imagenet-1k', trust_remote_code=True)
        if type == 'train':
            self.dataset = self.dataset['train']
        elif type == 'validation':
            self.dataset = self.dataset['validation']
        else:
            self.dataset = self.dataset['test']

        if num_samples is not None:
            # Randomly sample indices
            indices = random.sample(range(len(self.dataset)), num_samples)
            self.dataset = Subset(self.dataset, indices)

        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        if self.transform:
            image = self.dataset[idx]['image']
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return self.transform(image)
        else:
            return self.dataset[idx]['image']
@@ -1,378 +0,0 @@
from torch.utils.data import Dataset
from typing import Tuple, List, Optional
import torch
from PIL import Image
import os
import numpy as np
import random
from albumentations import Compose
import math
import itertools
import torch.nn.functional as F


class GTA5(Dataset):
    """
    GTA5 Dataset class for loading and transforming GTA5 dataset images and labels for semantic segmentation tasks.
    """

    def __init__(self, GTA5_path: str, transform: Optional[Compose] = None, FDA: float = None):
        """
        Initializes the GTA5 dataset class.

        This constructor sets up the dataset for use, optionally applying Frequency Domain Adaptation (FDA) and other transformations to the data.

        Args:
            GTA5_path (str): The root directory path where the GTA5 dataset is stored.
            transform (callable, optional): A function/transform that takes in an image and label and returns a transformed version. Defaults to None.
            FDA (float, optional): The beta value for Frequency Domain Adaptation. If None, FDA is not applied. Defaults to None.
        """
        self.GTA5_path = GTA5_path
        self.transform = transform
        self.FDA = FDA
        self.data = self._load_data()
        self.color_to_id = get_color_to_id()
        self.target_images = self._load_target_images() if FDA else []

    def _load_data(self) -> List[Tuple[str, str]]:
        """
        Load data paths for GTA5 dataset images and labels.

        This method walks through the directory structure of the GTA5 dataset, specifically looking for image files in the 'images' folder and corresponding label files in the 'labels' folder. It constructs a list of tuples, each containing the path to an image file and the corresponding label file.

        Returns:
            list: A list of tuples, each containing the path to an image file and the corresponding label file.
        """
        data = []
        image_dir = os.path.join(self.GTA5_path, 'images')
        label_dir = os.path.join(self.GTA5_path, 'labels')
        for image_filename in os.listdir(image_dir):
            image_path = os.path.join(image_dir, image_filename)
            label_path = os.path.join(label_dir, image_filename)
            data.append((image_path, label_path))
        return data

    def _load_target_images(self) -> List[Tuple[str, str]]:
        """
        Load target images for Frequency Domain Adaptation.

        This method walks through the directory structure of the Cityscapes dataset, specifically looking for image files in the 'gtFine' folder. It constructs a list of tuples, each containing the path to a label file and the corresponding image file.

        Returns:
            list: A list of tuples, each containing the path to a label file and the corresponding image file.
        """
        target_images = []
        city_path = self.GTA5_path.replace('GTA5', 'Cityscapes')
        city_image_dir = os.path.join(city_path, 'Cityspaces', 'gtFine', 'train')
        for root, _, files in os.walk(city_image_dir):
            for file in files:
                if 'Id' in file:
                    label_path = os.path.join(root, file)
                    image_path = label_path.replace('gtFine/', 'images/').replace('_gtFine_labelTrainIds', '_leftImg8bit')
                    target_images.append((label_path, image_path))
        return target_images

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the image and label at the specified index.

        Args:
            index (int): The index of the data point to retrieve.

        Returns:
            tuple: A tuple containing the transformed image and label.
        """
        img_path, label_path = self.data[index]
        img = Image.open(img_path).convert('RGB')
        label = self._convert_rgb_to_label(Image.open(label_path).convert('RGB'))
        img, label = np.array(img), np.array(label)
        center_padding = CenterPadding(14)

        if self.FDA:
            target_image_path = random.choice(self.target_images)[1]
            target_image = Image.open(target_image_path).convert('RGB').resize(img.shape[1::-1])
            img = FDA_transform(img, np.array(target_image), beta=self.FDA)

        if self.transform:
            transformed = self.transform(image=img, mask=label)
            img, label = transformed['image'], transformed['mask']

        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255
        label = torch.from_numpy(label).long()
        return center_padding(img), center_padding(label)

    def __len__(self) -> int:
        """
        Get the number of data points in the dataset.

        Returns:
            int: The number of data points in the dataset.
        """
        return len(self.data)

    def _convert_rgb_to_label(self, img: Image.Image) -> Image.Image:
        """
        Convert an RGB color-encoded label image to a single-channel grayscale label image.

        Args:
            img (Image.Image): The RGB label image to convert.

        Returns:
            Image.Image: The grayscale label image, with unknown colors mapped to 255.
        """
        gray_img = Image.new('L', img.size)
        label_pixels = img.load()
        gray_pixels = gray_img.load()

        for i in range(img.width):
            for j in range(img.height):
                rgb = label_pixels[i, j]
                gray_pixels[i, j] = self.color_to_id.get(rgb, 255)

        return gray_img
import numpy as np
from PIL import Image
import torch
from scipy.ndimage import gaussian_filter
from scipy.special import erfinv
import PIL


def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    """
    Compute a fast histogram for evaluating segmentation metrics.

    This function calculates a 2D histogram where each entry (i, j) counts the number of pixels that have the true label i and the predicted label j with a mask.

    Args:
        a (np.ndarray): An array of true labels.
        b (np.ndarray): An array of predicted labels.
        n (int): The number of different labels.

    Returns:
        np.ndarray: A 2D histogram of size (n, n).
    """
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iou(hist: np.ndarray) -> np.ndarray:
    """
    Calculate the Intersection over Union (IoU) for each class.

    The IoU is computed for each class using the histogram of true and predicted labels. It is defined as the ratio of the diagonal elements of the histogram to the sum of the corresponding rows and columns, adjusted by the diagonal elements and a small epsilon to avoid division by zero.

    Args:
        hist (np.ndarray): A 2D histogram where each entry (i, j) is the count of pixels with true label i and predicted label j.

    Returns:
        np.ndarray: An array containing the IoU for each class.
    """
    epsilon = 1e-5
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)


def poly_lr_scheduler(optimizer: torch.optim.Optimizer, init_lr: float, iter: int, lr_decay_iter: int = 1,
                      max_iter: int = 50, power: float = 0.9) -> float:
    """
    Adjusts the learning rate of the optimizer for each iteration using a polynomial decay schedule.

    This function updates the learning rate of the optimizer based on the current iteration number and a polynomial decay schedule. The learning rate is calculated as:

        lr = init_lr * (1 - iter / max_iter) ** power

    where `init_lr` is the initial learning rate, `iter` is the current iteration number, `max_iter` is the maximum number of iterations, and `power` is the exponent used in the polynomial decay.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer for which to adjust the learning rate.
        init_lr (float): The initial learning rate.
        iter (int): The current iteration number.
        lr_decay_iter (int): The iteration interval after which the learning rate is decayed. Default is 1.
        max_iter (int): The maximum number of iterations after which no more decay will happen.
        power (float): The exponent used in the polynomial decay of the learning rate.

    Returns:
        float: The updated learning rate.
    """
    # if iter % lr_decay_iter or iter > max_iter:
    #     return optimizer

    lr = init_lr * (1 - iter / max_iter) ** power
    optimizer.param_groups[0]['lr'] = lr
    return lr


def label_to_rgb(label: np.ndarray, height: int, width: int) -> PIL.Image:
    """
    Transforms a label matrix into a corresponding RGB image utilizing a predefined color map.

    This function maps each label identifier in a two-dimensional array to a specific color, thereby generating an RGB image. This is particularly useful for visualizing segmentation results where each label corresponds to a different segment class.

    Parameters:
        label (np.ndarray): A two-dimensional array where each element represents a label identifier.
        height (int): The desired height of the resulting RGB image.
        width (int): The desired width of the resulting RGB image.

    Returns:
        PIL.Image: An image object representing the RGB image constructed from the label matrix.
    """
    id_to_color = get_id_to_color()

    height, width = label.shape
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
    for i in range(height):
        for j in range(width):
            class_id = label[i, j]
            rgb_image[i, j] = id_to_color.get(class_id, (255, 255, 255))  # Default to white if not found
    pil_image = Image.fromarray(rgb_image, 'RGB')
    return pil_image


def generate_cow_mask(img_size: tuple, sigma: float, p: float, batch_size: int) -> np.ndarray:
    """
    Generates a batch of cow masks based on a Gaussian noise model.

    Parameters:
        img_size (tuple): The size of the images (height, width).
        sigma (float): The standard deviation of the Gaussian filter applied to the noise.
        p (float): The desired proportion of the mask that should be 'cow'.
        batch_size (int): The number of masks to generate.

    Returns:
        np.ndarray: A batch of cow masks of shape (batch_size, 1, height, width).
    """
    N = np.random.normal(size=img_size)
    Ns = gaussian_filter(N, sigma)
    t = erfinv(p * 2 - 1) * (2 ** 0.5) * Ns.std() + Ns.mean()
    masks = []
    for i in range(batch_size):
        masks.append((Ns > t).astype(float).reshape(1, *img_size))
    return np.array(masks)


def get_id_to_label() -> dict:
    """
    Returns a dictionary mapping class IDs to their corresponding labels.

    Returns:
        dict: A dictionary where keys are class IDs and values are labels.
    """
    return {
        0: 'road',
        1: 'sidewalk',
        2: 'building',
        3: 'wall',
        4: 'fence',
        5: 'pole',
        6: 'light',
        7: 'sign',
        8: 'vegetation',
        9: 'terrain',
        10: 'sky',
        11: 'person',
        12: 'rider',
        13: 'car',
        14: 'truck',
        15: 'bus',
        16: 'train',
        17: 'motorcycle',
        18: 'bicycle',
        255: 'unlabeled'
    }


def get_id_to_color() -> dict:
    """
    Returns a dictionary mapping class IDs to their corresponding colors.

    Returns:
        dict: A dictionary where keys are class IDs and values are RGB color tuples.
    """
    id_to_color = {
        0: (128, 64, 128),    # road
        1: (244, 35, 232),    # sidewalk
        2: (70, 70, 70),      # building
        3: (102, 102, 156),   # wall
        4: (190, 153, 153),   # fence
        5: (153, 153, 153),   # pole
        6: (250, 170, 30),    # light
        7: (220, 220, 0),     # sign
        8: (107, 142, 35),    # vegetation
        9: (152, 251, 152),   # terrain
        10: (70, 130, 180),   # sky
        11: (220, 20, 60),    # person
        12: (255, 0, 0),      # rider
        13: (0, 0, 142),      # car
        14: (0, 0, 70),       # truck
        15: (0, 60, 100),     # bus
        16: (0, 80, 100),     # train
        17: (0, 0, 230),      # motorcycle
        18: (119, 11, 32),    # bicycle
    }
    return id_to_color


def get_color_to_id() -> dict:
    """
    Returns a dictionary mapping RGB color tuples to their corresponding class IDs.

    Returns:
        dict: A dictionary where keys are RGB color tuples and values are class IDs.
    """
    id_to_color = get_id_to_color()
    color_to_id = {color: id for id, color in id_to_color.items()}
    return color_to_id


def mix(mask, data=None, target=None):
    # Mix each sample with its neighbour (or its pair) under the given mask.
    if not (data is None):
        if mask.shape[0] == data.shape[0]:
            data = torch.cat([(mask[i] * data[i] + (1 - mask[i]) * data[(i + 1) % data.shape[0]]).unsqueeze(0) for i in range(data.shape[0])])
        elif mask.shape[0] == data.shape[0] / 2:
            data = torch.cat((torch.cat([(mask[i] * data[2 * i] + (1 - mask[i]) * data[2 * i + 1]).unsqueeze(0) for i in range(int(data.shape[0] / 2))]),
                              torch.cat([((1 - mask[i]) * data[2 * i] + mask[i] * data[2 * i + 1]).unsqueeze(0) for i in range(int(data.shape[0] / 2))])))
    if not (target is None):
        target = torch.cat([(mask[i] * target[i] + (1 - mask[i]) * target[(i + 1) % target.shape[0]]).unsqueeze(0) for i in range(target.shape[0])])
    return data, target


def generate_class_mask(pred, classes):
    pred, classes = torch.broadcast_tensors(pred.unsqueeze(0), classes.unsqueeze(1).unsqueeze(2))
    N = pred.eq(classes).sum(0)
    return N


class CenterPadding(torch.nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        output = F.pad(x, pads)
        return output
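The per-image confusion matrices from fast_hist are meant to be accumulated before per_class_iou is applied. A small sketch of computing mIoU with the two helpers above; the label arrays are random placeholders, and the 19-class count follows the Cityscapes-style ID map defined in this file:

import numpy as np

num_classes = 19
gt = np.random.randint(0, num_classes, size=(2, 512 * 1024))    # hypothetical ground-truth maps, flattened
pred = np.random.randint(0, num_classes, size=(2, 512 * 1024))  # hypothetical predictions, flattened

hist = np.zeros((num_classes, num_classes))
for g, p in zip(gt, pred):
    hist += fast_hist(g, p, num_classes)   # accumulate confusion counts over images

ious = per_class_iou(hist)                 # one IoU value per class
print(f"mIoU: {np.nanmean(ious):.4f}")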
@@ -1,3 +0,0 @@
from .helper import get_teacher_and_student
from .backbones import CustomResNet, DINOv2ViT
from .resnet import ResNet
@@ -1,64 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureMixerLayer(nn.Module):
    def __init__(self, in_dim, mlp_ratio=1):
        super().__init__()
        self.mix = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, int(in_dim * mlp_ratio)),
            nn.ReLU(),
            nn.Linear(int(in_dim * mlp_ratio), in_dim),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        return x + self.mix(x)


class MixVPR(nn.Module):
    def __init__(self,
                 in_channels=1024,
                 in_h=14,
                 in_w=14,
                 out_channels=512,
                 mix_depth=4,
                 mlp_ratio=1,
                 out_rows=4,
                 ) -> None:
        super().__init__()

        self.in_h = in_h                  # height of input feature maps
        self.in_w = in_w                  # width of input feature maps
        self.in_channels = in_channels    # depth of input feature maps

        self.out_channels = out_channels  # depth-wise projection dimension
        self.out_rows = out_rows          # row-wise projection dimension

        self.mix_depth = mix_depth        # L, the number of stacked FeatureMixers
        self.mlp_ratio = mlp_ratio        # ratio of the mid projection layer in the mixer block

        hw = in_h * in_w
        self.mix = nn.Sequential(*[
            FeatureMixerLayer(in_dim=hw, mlp_ratio=mlp_ratio)
            for _ in range(self.mix_depth)
        ])
        self.channel_proj = nn.Linear(in_channels, out_channels)
        self.row_proj = nn.Linear(hw, out_rows)

    def forward(self, x):
        x = x.flatten(2)
        x = self.mix(x)
        x = x.permute(0, 2, 1)
        x = self.channel_proj(x)
        x = x.permute(0, 2, 1)
        x = self.row_proj(x)
        x = F.normalize(x.flatten(1), p=2, dim=1)
        return x
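A quick shape check for the aggregator above; the default arguments imply a 1024-channel 14x14 feature map, and the batch size here is arbitrary:

import torch

agg = MixVPR(in_channels=1024, in_h=14, in_w=14, out_channels=512, mix_depth=4, mlp_ratio=1, out_rows=4)
feats = torch.randn(2, 1024, 14, 14)   # e.g. a ResNet layer3 feature map at 224x224 input
desc = agg(feats)
print(desc.shape)                      # torch.Size([2, 2048]) -- out_channels * out_rows, L2-normalized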
@@ -1,131 +0,0 @@
from utils import match_vit_features

import torch
from torch import nn
import torchvision.models as models
import math
from .aggregators import MixVPR


class CustomResNet(nn.Module):
    def __init__(self, model_name='resnet50', patch_size=14, pretrained=True):
        super().__init__()
        """
        Custom ResNet model for distillation.

        Args:
            model_name (str): The name of the ResNet model to use.
            patch_size (int): The patch size of the ViT model.
            pretrained (bool): Whether to use pretrained weights.

        Returns:
            dict: A dictionary containing:
                feature_map: torch.Tensor (B, C, H, W): The feature map after the last layer of ResNet.
                embeddings: torch.Tensor (B, D): The embeddings after global average pooling.
                patches: torch.Tensor (B, N, C): The patches which match the ViT patch grid.
        """
        if pretrained:
            base_model = getattr(models, model_name)(weights='IMAGENET1K_V1')
        else:
            base_model = getattr(models, model_name)(weights=None)
        self.patch_size = patch_size

        # Split the model into layers
        self.conv1 = base_model.conv1
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool

        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4

        self.avgpool = base_model.avgpool
        in_channels = 2048 if model_name in ['resnet50', 'resnet101', 'resnet152'] else 512
        self.feature_matcher = nn.Conv2d(in_channels, 1536, kernel_size=1)

        # self.mix = MixVPR(in_channels=int(in_channels), in_h=7, in_w=7, out_channels=512, mix_depth=4, mlp_ratio=1, out_rows=4)
        self.attention = nn.MultiheadAttention(embed_dim=1536, num_heads=16, batch_first=True)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        layer_3 = self.layer3(x)
        layer4 = self.layer4(layer_3)  # Final feature map before pooling
        feature_map = torch.nn.functional.interpolate(
            layer4,
            size=(16, 16),
            mode='bilinear',
            align_corners=False
        )
        feature_map = self.feature_matcher(feature_map)

        pooled = self.avgpool(feature_map)     # Global average pooling
        embeddings = torch.flatten(pooled, 1)  # Flatten to get embeddings

        # print(f"layer_3 shape: {layer_3.shape}")
        # contrastive_embeddings = self.mix(layer4)
        B, C, H, W = feature_map.shape
        tokens = feature_map.view(B, C, H * W).permute(0, 2, 1)             # [B, T, C]
        attn_output, attn_weights = self.attention(tokens, tokens, tokens)  # attn_weights: [B, T, T]
        return {
            'feature_map': layer4,            # Final feature map after layer4
            'embedding': embeddings,          # Embeddings after pooling
            # 'patch_embeddings': patches,
            'dinov2_feature_map': feature_map,
            # 'contrastive_embeddings': contrastive_embeddings,
            'attn_weights': attn_weights,
        }


class DINOv2ViT(nn.Module):
    def __init__(self, model_name='dinov2_vitg14'):
        super().__init__()
        """
        DINOv2 ViT model for distillation.

        Args:
            model_name (str): The name of the DINOv2 ViT model to use.

        Returns:
            dict: A dictionary containing:
                patch_embeddings: torch.Tensor (B, N, D): The patch embeddings excluding the CLS token.
                cls_embedding: torch.Tensor (B, D): The CLS token embedding.
        """
        # Load model from torch hub
        self.model = torch.hub.load('facebookresearch/dinov2', model_name)
        # Freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        # Get features from the model's last layer
        patch_embeddings, cls_token = self.model.get_intermediate_layers(x, n=1, return_class_token=True)[0]  # [B, N+1, D]
        # print(f" patch_embeddings shape: {patch_embeddings.shape}")
        # print(f" cls_token shape: {cls_token.shape}")
        # Convert patch embeddings to feature map format
        B, N, D = patch_embeddings.shape
        P = int(math.sqrt(N))  # N patch tokens (CLS already split off)
        feature_map = patch_embeddings.reshape(B, P, P, D).permute(0, 3, 1, 2)  # [B, D, P, P]

        return {
            'patch_embeddings': patch_embeddings,  # Per-patch embeddings excluding CLS
            'embedding': cls_token,                # CLS token embedding
            'feature_map': feature_map,
        }
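As a sanity check of how the two backbones line up for distillation: a minimal sketch assuming a 224x224 input (the global_crops_size from the config); the expected shapes follow the code above, and loading dinov2_vitg14 from torch hub requires network access or a local cache:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(1, 3, 224, 224).to(device)

teacher = DINOv2ViT('dinov2_vitg14').to(device).eval()  # ViT-g/14 -> 16x16 patch grid, D=1536
student = CustomResNet('resnet50').to(device).eval()

with torch.no_grad():
    t_out = teacher(x)
    s_out = student(x)

print(t_out["feature_map"].shape)         # torch.Size([1, 1536, 16, 16])
print(s_out["dinov2_feature_map"].shape)  # torch.Size([1, 1536, 16, 16]) -- matches the teacher grid
print(s_out["embedding"].shape)           # torch.Size([1, 1536])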
@@ -1,69 +0,0 @@
import sys
sys.path.append("../../dinov2")
from dinov2.layers.dino_head import DINOHead  # Adjust based on actual location
from .backbones import DINOv2ViT, CustomResNet
import torch.nn as nn


def get_teacher_and_student(cfg="../config/config.yaml"):
    """
    Initialize and return the teacher and student models based on the provided configuration.

    This function creates the backbone and heads for both teacher and student models
    as specified in the configuration. It supports different architectures for the student,
    such as ResNet and DINOv2 Vision Transformer (ViT). The models are encapsulated
    within `nn.ModuleDict` objects for easy integration and training.

    Args:
        cfg (str or dict): Path to the YAML configuration file or a dictionary
            containing configuration parameters. Defaults to "../config/config.yaml".
            (Note: the body below indexes `cfg` directly, so in practice a dict is expected.)

    Returns:
        tuple:
            - teacher (nn.ModuleDict): A module dictionary containing the teacher's backbone
              and associated heads (`dino_head`, `ibot_head`).
            - student (nn.ModuleDict): A module dictionary containing the student's backbone
              and associated heads (`dino_head`, `ibot_head`).

    Raises:
        ValueError: If the student model name specified in the configuration is unsupported.
    """
    student = {}
    teacher = {}

    # Create teacher and student backbones
    teacher_backbone = DINOv2ViT(model_name=cfg["teacher"]["model_name"])
    teacher['backbone'] = teacher_backbone

    if cfg["student"]["model_name"].startswith('resnet'):
        student_backbone = CustomResNet(model_name=cfg["student"]["model_name"], pretrained=True)
    elif cfg["student"]["model_name"].startswith('dino'):
        student_backbone = DINOv2ViT(model_name=cfg["student"]["model_name"])
    else:
        raise ValueError(f"Unsupported student model: {cfg['student']['model_name']}")

    student['backbone'] = student_backbone

    student["dino_head"] = DINOHead(in_dim=cfg["student"]["dino_head"]["in_dim"],
                                    out_dim=cfg["student"]["dino_head"]["out_dim"],
                                    hidden_dim=cfg["student"]["dino_head"]["hidden_dim"],
                                    bottleneck_dim=cfg["student"]["dino_head"]["bottleneck_dim"])
    teacher["dino_head"] = DINOHead(in_dim=cfg["teacher"]["dino_head"]["in_dim"],
                                    out_dim=cfg["teacher"]["dino_head"]["out_dim"],
                                    hidden_dim=cfg["teacher"]["dino_head"]["hidden_dim"],
                                    bottleneck_dim=cfg["teacher"]["dino_head"]["bottleneck_dim"])
    student["ibot_head"] = DINOHead(in_dim=cfg["student"]["ibot_head"]["in_dim"],
                                    out_dim=cfg["student"]["ibot_head"]["out_dim"],
                                    hidden_dim=cfg["student"]["ibot_head"]["hidden_dim"],
                                    bottleneck_dim=cfg["student"]["ibot_head"]["bottleneck_dim"])
    teacher["ibot_head"] = DINOHead(in_dim=cfg["teacher"]["ibot_head"]["in_dim"],
                                    out_dim=cfg["teacher"]["ibot_head"]["out_dim"],
                                    hidden_dim=cfg["teacher"]["ibot_head"]["hidden_dim"],
                                    bottleneck_dim=cfg["teacher"]["ibot_head"]["bottleneck_dim"])

    student = nn.ModuleDict(student)
    teacher = nn.ModuleDict(teacher)

    return teacher, student
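The helper reads `dino_head` and `ibot_head` dimensions that the config.yaml in this commit does not define, so for illustration only, the shape of the dict it expects; every dimension value below is a placeholder assumption, not a value taken from this repository:

cfg = {
    "teacher": {
        "model_name": "dinov2_vitg14",
        "dino_head": {"in_dim": 1536, "out_dim": 65536, "hidden_dim": 2048, "bottleneck_dim": 256},
        "ibot_head": {"in_dim": 1536, "out_dim": 65536, "hidden_dim": 2048, "bottleneck_dim": 256},
    },
    "student": {
        "model_name": "resnet50",
        "dino_head": {"in_dim": 1536, "out_dim": 65536, "hidden_dim": 2048, "bottleneck_dim": 256},
        "ibot_head": {"in_dim": 1536, "out_dim": 65536, "hidden_dim": 2048, "bottleneck_dim": 256},
    },
}

teacher, student = get_teacher_and_student(cfg)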
@@ -1,34 +0,0 @@
import torch
import torch.nn as nn
from ...distillation_real.models.students.resnet import ResNet, BottleneckBlock, BasicStem, make_resnet_stages
# Import other student models when added
# from .efficientnet import EfficientNet, ...


class StudentModelFactory:
    """
    Factory class for creating student models.
    """
    @staticmethod
    def create_student(model_type: str, config: dict, device: torch.device) -> nn.Module:
        if model_type.lower() == 'resnet':
            stem = BasicStem(
                in_channels=config.get('in_channels', 3),
                out_channels=64,
                norm=config.get('norm_type', 'BN')
            )
            stages = make_resnet_stages(
                depth=config.get('model_depth', 50),
                block_class=config.get('block_class', BottleneckBlock),
                norm=config.get('norm_type', 'BN'),
                dilation=config.get('dilation', (1, 1, 1, 1))
            )
            student = ResNet(
                stem=stem,
                stages=stages,
                out_features=None,
                freeze_at=config.get('freeze_at', 0)
            ).to(device)
            return student
        # elif model_type.lower() == 'efficientnet':
        #     return EfficientNet(config).to(device)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
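A minimal usage sketch for the factory above; the config keys simply mirror the .get() defaults in the code, and the device choice is an assumption:

import torch

config = {"model_depth": 50, "norm_type": "BN", "freeze_at": 0}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student = StudentModelFactory.create_student("resnet", config, device)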
@@ -1,127 +0,0 @@
import torch
import torch.nn as nn
from ...distillation_real.models.students.resnet import ResNet, BasicStem, BottleneckBlock, make_resnet_stages
import numpy as np


class ModelWrapper(nn.Module):
    """
    A generalizable and parametric wrapper for student models and feature matching.

    Args:
        model_type (str): Type of the model to initialize (e.g., 'resnet').
        model_depth (int, optional): Depth of the ResNet model (default: 50).
        block_class (nn.Module, optional): Block class for ResNet (default: BottleneckBlock).
        norm_type (str, optional): Normalization type (default: 'BN').
        in_channels (int, optional): Number of input channels (default: 3).
        feature_matcher_config (dict, optional): Configuration for the feature matcher.
            Should include 'in_channels', 'out_channels', 'kernel_size', and other relevant parameters.
        device (torch.device, optional): Device to load the model on (default: 'cpu').

    Attributes:
        student (nn.Module): The student model.
        feature_matcher (nn.Module): The feature matcher module.
    """
    def __init__(
        self,
        model_type='resnet',
        model_depth=50,
        block_class=BottleneckBlock,
        norm_type='BN',
        in_channels=3,
        n_patches=256,
        feature_matcher_config=None,
        device=torch.device('cpu')
    ):
        super(ModelWrapper, self).__init__()
        self.device = device
        self.n_patches = n_patches
        # Initialize the student model based on the specified type
        if model_type.lower() == 'resnet':
            stem = BasicStem(in_channels=in_channels, out_channels=64, norm=norm_type)
            stages = make_resnet_stages(
                depth=model_depth,
                block_class=block_class,
                norm=norm_type,
                dilation=(1, 1, 1, 1)
            )
            self.student = ResNet(
                stem=stem,
                stages=stages,
                out_features=None,
                freeze_at=0
            ).to(self.device)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        # Initialize the feature matcher if configuration is provided
        if feature_matcher_config:
            self.feature_matcher = self._initialize_feature_matcher(feature_matcher_config)
        else:
            self.feature_matcher = None

    def _initialize_feature_matcher(self, config):
        """
        Initializes the feature matcher based on the provided configuration.

        Args:
            config (dict): Configuration dictionary for the feature matcher.

        Returns:
            nn.Module: Initialized feature matcher module.
        """
        layers = []
        in_channels = config.get('in_channels', 2048)
        out_channels = config.get('out_channels', 1536)
        kernel_size = config.get('kernel_size', 1)
        stride = config.get('stride', 1)
        padding = config.get('padding', 0)
        activation = config.get('activation', None)

        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
        if activation:
            layers.append(getattr(nn, activation)())

        return nn.Sequential(*layers).to(self.device)

    def forward(self, x):
        """
        Forward pass through the student model and feature matcher.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: (matched_features, student_output)
        """
        # Forward pass through the student model
        student_output = self.student(x)

        # If feature matcher is defined, process the feature maps
        if self.feature_matcher and 'res5' in student_output:
            res5_feature = student_output['res5']
            interpolated_feature = torch.nn.functional.interpolate(
                res5_feature,
                size=(int(np.sqrt(self.n_patches)), int(np.sqrt(self.n_patches))),
                mode='bilinear',
                align_corners=False
            )
            matched_features = self.feature_matcher(interpolated_feature)
            return matched_features, student_output
        else:
            return None, student_output

    def get_feature_map_shape(self, x, feature_key='res5'):
        """
        Utility method to get the shape of a specified feature map.

        Args:
            x (torch.Tensor): Input tensor.
            feature_key (str, optional): Key of the feature map in student output (default: 'res5').

        Returns:
            torch.Size: Shape of the specified feature map.
        """
        output = self.student(x)
        if feature_key in output:
            return output[feature_key].shape
        else:
            raise KeyError(f"Feature key '{feature_key}' not found in student output.")
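A sketch of instantiating the wrapper with the values from the `model_wrapper` block of config.yaml; n_patches=256 gives a 16x16 grid, matching the 1536-dim DINOv2 ViT-g/14 teacher at 224x224. The feature matcher's in_channels and the CPU device are assumptions, and the output shape comment only holds if the student exposes a 'res5' map:

import torch

wrapper = ModelWrapper(
    model_type="resnet",
    n_patches=256,
    feature_matcher_config={"in_channels": 2048, "out_channels": 1536, "kernel_size": 1},
    device=torch.device("cpu"),
)

x = torch.randn(2, 3, 224, 224)
matched, student_out = wrapper(x)
if matched is not None:
    print(matched.shape)  # expected torch.Size([2, 1536, 16, 16])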
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,494 +0,0 @@
Jupyter notebook (kernel "dinov2", Python 3.9.20, nbformat 4.2); cell sources and captured outputs:

In [1]:
from models import DINOv2ViT, CustomResNet
import torch
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

x = torch.randn(1, 3, 224, 224).to(device)
teacher = DINOv2ViT().to(device)

# out = teacher(x)
# print(out["patch_embeddings"].shape)
# print(out["embedding"].shape)
# print(out["feature_map"].shape)

student = CustomResNet().to(device)
out = student(x)
print(out["dinov2_feature_map"].shape)
print(out["embedding"].shape)
print(out["contrastive_embeddings"].shape)

Output [1] (stderr):
/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/swiglu_ffn.py:43: UserWarning: xFormers is available (SwiGLU)
  warnings.warn("xFormers is available (SwiGLU)")
/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/attention.py:27: UserWarning: xFormers is available (Attention)
  warnings.warn("xFormers is available (Attention)")
/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/block.py:33: UserWarning: xFormers is available (Block)
  warnings.warn("xFormers is available (Block)")
Using cache found in /home/arda/.cache/torch/hub/facebookresearch_dinov2_main

Output [1] (stdout):
layer_3 shape: torch.Size([1, 1024, 14, 14])
torch.Size([1, 1536, 16, 16])
torch.Size([1, 1536])
torch.Size([1, 2048])

In [2]:
# from datasets.GTA5 import GTA5Dataset
import sys
sys.path.append('./datasets')  # Add the datasets directory to the Python path

from collate_fn import collate_data_and_cast  # Adjusted import statement
# from datasets.collate_fn import collate_data_and_cast
from dinov2.data.augmentations import DataAugmentationDINO
from torch.utils.data import DataLoader
from imagenet import ImageNetDataset

import yaml

# Load configurations
with open("config/config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

# Data Transformation
data_transform = DataAugmentationDINO(
    global_crops_scale=tuple(cfg['data_transform']['global_crops_scale']),
    local_crops_scale=tuple(cfg['data_transform']['local_crops_scale']),
    local_crops_number=cfg['data_transform']['n_local_crops'],
    global_crops_size=tuple(cfg['data_transform']['global_crops_size']),
    local_crops_size=tuple(cfg['data_transform']['local_crops_size']),
)

# Create train and test datasets
train_dataset = ImageNetDataset(type='train', transform=data_transform, num_samples=5000)
test_dataset = ImageNetDataset(type='test', transform=data_transform, num_samples=500)
# Create train and test dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg['data_loader']['batch_size'],
    num_workers=cfg['data_loader']['num_workers'],
    shuffle=cfg['data_loader']['shuffle'],
    collate_fn=collate_data_and_cast
)

test_loader = DataLoader(
    test_dataset,
    batch_size=cfg['data_loader']['batch_size'],
    num_workers=cfg['data_loader']['num_workers'],
    shuffle=False,
    collate_fn=collate_data_and_cast
)

# Optimizer
optimizer = getattr(torch.optim, cfg['optimizer']['type'])([
    {"params": student.parameters()},
], lr=2.5e-4)

# Freeze teacher model
for param in teacher.parameters():
    param.requires_grad = False

Output [2] (progress widgets):
Loading dataset shards:   0%|          | 0/257 [00:00<?, ?it/s]
Loading dataset shards:   0%|          | 0/25 [00:00<?, ?it/s]
Loading dataset shards:   0%|          | 0/257 [00:00<?, ?it/s]
Loading dataset shards:   0%|          | 0/25 [00:00<?, ?it/s]

In [3]:
from tqdm import tqdm
import os
import torch.nn.functional as F  # Added import for functional operations
from torch.cuda.amp import GradScaler, autocast  # Import for mixed precision

best_test_similarity = 0
save_frequency = 5
checkpoint_dir = "./checkpoints"
scaler = GradScaler()

def compute_feature_similarity(feat1, feat2):
    # Reshape feature maps to 2D: (batch*height*width, channels)
    f1 = feat1.reshape(-1, feat1.shape[-1])
    f2 = feat2.reshape(-1, feat2.shape[-1])

    # Compute cosine similarity
    similarity = torch.nn.functional.cosine_similarity(f1, f2, dim=1)
    return similarity.mean()

def compound_loss(mse_loss, cosine_sim_loss, alpha=1.0, beta=1.0):
    """
    Combine MSE loss and Cosine Similarity loss.

    Args:
        mse_loss (torch.Tensor): Mean Squared Error loss.
        cosine_sim_loss (torch.Tensor): Cosine Similarity loss.
        alpha (float): Weight for MSE loss.
        beta (float): Weight for Cosine Similarity loss.

    Returns:
        torch.Tensor: Combined loss.
    """
    return alpha * mse_loss + beta * cosine_sim_loss

checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth")

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    student.load_state_dict(checkpoint['student_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_test_similarity = checkpoint['best_test_similarity']
    print(f"Resuming from epoch {start_epoch}")

for epoch in range(1000):
    epoch_loss = []
    similarities = []
    embedding_similarities = []
    student.train()
    teacher.eval()
    for i, data in enumerate(tqdm(train_loader)):
        global_crops = data["collated_global_crops"].to(device)
        local_crops = data["collated_local_crops"].to(device)

        # Mixed precision training
        with autocast():
            # Get feature maps from teacher
            with torch.no_grad():
                teacher_output = teacher(global_crops)
                teacher_feature_maps = teacher_output["feature_map"]
                teacher_embedding = teacher_output["embedding"]

            # Get feature maps from student
            student_output = student(global_crops)
            student_feature_maps = student_output["dinov2_feature_map"]
            student_embedding = student_output["embedding"]

            # Calculate MSE loss between feature maps
            mse_loss = torch.nn.functional.mse_loss(
                student_feature_maps,
                teacher_feature_maps
            )
            mse_embedding_loss = torch.nn.functional.mse_loss(
                student_embedding,
                teacher_embedding
            )

            # Calculate Cosine Similarity loss
            student_feature_normalized = F.normalize(student_feature_maps, p=2, dim=1)
            teacher_feature_normalized = F.normalize(teacher_feature_maps, p=2, dim=1)
            cosine_similarity = torch.nn.functional.cosine_similarity(
                student_feature_normalized,
                teacher_feature_normalized,
                dim=1
            )
            cosine_similarity_loss = 1 - cosine_similarity.mean()  # Convert similarity to loss

            student_embedding_normalized = F.normalize(student_embedding, p=2, dim=1)
            teacher_embedding_normalized = F.normalize(teacher_embedding, p=2, dim=1)
            cosine_similarity_embedding = torch.nn.functional.cosine_similarity(
                student_embedding_normalized,
                teacher_embedding_normalized,
                dim=1
            )
            cosine_similarity_embedding_loss = 1 - cosine_similarity_embedding.mean()  # Convert similarity to loss

            # Combine the losses
            total_loss = compound_loss(mse_loss, cosine_similarity_loss, alpha=1.0, beta=1.0)
            total_embedding_loss = compound_loss(mse_embedding_loss, cosine_similarity_embedding_loss, alpha=1.0, beta=1.0)
            total_loss += total_embedding_loss
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Calculate similarity for logging
        similarity = compute_feature_similarity(student_feature_maps, teacher_feature_maps)
        similarities.append(similarity.item())
        # Calculate embedding similarity for logging
        embedding_similarity = compute_feature_similarity(student_embedding, teacher_embedding)
        embedding_similarities.append(embedding_similarity.item())

        epoch_loss.append(total_loss.item())

    # Evaluation on test set
    student.eval()
    test_losses = []
    test_similarities = []
    test_embedding_similarities = []
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_loader)):
            global_crops = data["collated_global_crops"].to(device)

            teacher_output = teacher(global_crops)
            student_output = student(global_crops)

            # Feature map losses
            test_mse = torch.nn.functional.mse_loss(
                student_output["dinov2_feature_map"],
                teacher_output["feature_map"]
            )
            test_similarity = compute_feature_similarity(
                student_output["dinov2_feature_map"],
                teacher_output["feature_map"]
            )

            # Embedding losses
            test_embedding_mse = torch.nn.functional.mse_loss(
                student_output["embedding"],
                teacher_output["embedding"]
            )
            test_embedding_similarity = compute_feature_similarity(
                student_output["embedding"],
                teacher_output["embedding"]
            )

            test_losses.append(test_mse.item() + test_embedding_mse.item())
            test_similarities.append(test_similarity.item())
            test_embedding_similarities.append(test_embedding_similarity.item())

    # Calculate average metrics
    avg_train_loss = sum(epoch_loss) / len(epoch_loss)
    avg_train_similarity = sum(similarities) / len(similarities)
    avg_train_embedding_similarity = sum(embedding_similarities) / len(embedding_similarities)
    avg_test_loss = sum(test_losses) / len(test_losses)
    avg_test_similarity = sum(test_similarities) / len(test_similarities)
    avg_test_embedding_similarity = sum(test_embedding_similarities) / len(test_embedding_similarities)

    # Print metrics
    print(f"Epoch {epoch}")
    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Train Feature Similarity: {avg_train_similarity:.4f}")
    print(f"Train Embedding Similarity: {avg_train_embedding_similarity:.4f}")
    print(f"Test Loss: {avg_test_loss:.4f}")
    print(f"Test Feature Similarity: {avg_test_similarity:.4f}")
    print(f"Test Embedding Similarity: {avg_test_embedding_similarity:.4f}")

    # Save checkpoint
    if (epoch + 1) % save_frequency == 0:
        checkpoint = {
            'epoch': epoch,
            'student_state_dict': student.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'test_loss': avg_test_loss,
            'train_feature_similarity': avg_train_similarity,
            'train_embedding_similarity': avg_train_embedding_similarity,
            'test_similarity': avg_test_similarity,
            'best_test_similarity': best_test_similarity
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth"))
        torch.save(checkpoint, os.path.join(checkpoint_dir, "latest_checkpoint.pth"))

    # Save best model
    if avg_test_similarity > best_test_similarity:
        best_test_similarity = avg_test_similarity
        torch.save({
            'epoch': epoch,
            'student_state_dict': student.state_dict(),
            'test_similarity': avg_test_similarity,
            'test_embedding_similarity': avg_test_similarity
        }, os.path.join(checkpoint_dir, "best_model.pth"))

Output [3] (stderr/stdout, interleaved):
  0%|          | 0/16 [00:00<?, ?it/s]
/home/arda/miniconda3/envs/dinov2/lib/python3.9/site-packages/xformers/ops/unbind.py:46: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  storage_data_ptr = tensors[0].storage().data_ptr()
/home/arda/miniconda3/envs/dinov2/lib/python3.9/site-packages/xformers/ops/unbind.py:48: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if x.storage().data_ptr() != storage_data_ptr:
100%|██████████| 16/16 [01:01<00:00,  3.82s/it]
100%|██████████| 2/2 [00:12<00:00,  6.46s/it]
Epoch 0
Train Loss: 5.8403
Train Feature Similarity: 0.0879
Train Embedding Similarity: 0.0714
Test Loss: 4.1575
Test Similarity: 0.0881
100%|██████████| 16/16 [01:03<00:00,  3.98s/it]
100%|██████████| 2/2 [00:12<00:00,  6.50s/it]
Epoch 1
Train Loss: 5.5363
Train Feature Similarity: 0.1243
Train Embedding Similarity: 0.0848
Test Loss: 3.9246
Test Similarity: 0.1102
100%|██████████| 16/16 [01:05<00:00,  4.12s/it]
 50%|█████     | 1/2 [00:10<00:10, 10.78s/it]

In [5]:
labels = torch.arange(32).repeat(2)  # Creates [0,1,2,...,batch_size-1, 0,1,2,...,batch_size-1]
labels = labels.to(device)

In [6]:
labels

Output [6]:
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  0,  1,  2,  3,
         4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31], device='cuda:1')

In [ ]:
(empty cell)
@@ -1,957 +0,0 @@
Jupyter notebook; cell sources and captured outputs (listing truncated):

In [1]:
import os
from datasets_gta5 import GTA5
import albumentations as A
import torch

GTA5_PATH = '/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/GTA5/GTA5'
GTA5_IMAGES = os.path.join(GTA5_PATH, 'images')
GTA5_LABELS = os.path.join(GTA5_PATH, 'labels')

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [2]:
transform = A.Compose([
    A.Resize(512, 1024)
])
GTA5_dataset = GTA5(GTA5_path=GTA5_PATH, transform=transform)

In [3] (only the cell's outputs are visible before the listing cuts off):

Output [3] (stderr):
/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/swiglu_ffn.py:43: UserWarning: xFormers is available (SwiGLU)
  warnings.warn("xFormers is available (SwiGLU)")
/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/attention.py:27: UserWarning: xFormers is available (Attention)
  warnings.warn("xFormers is available (Attention)")
/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/block.py:33: UserWarning: xFormers is available (Block)
  warnings.warn("xFormers is available (Block)")

Output [3] (result):
CustomResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (3): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer3): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (3): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (4): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (5): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer4): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
|
||||
" (feature_matcher): Conv2d(2048, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from models import CustomResNet\n",
|
||||
"encoder = CustomResNet()\n",
|
||||
"# encoder.load_state_dict(torch.load('student_backbone_checkpoint.pth')['backbone_state_dict'])\n",
|
||||
"\n",
|
||||
"encoder.eval()\n",
|
||||
"encoder.to(device)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"asd = torch.randn(1, 3, 512, 1024).to(device)\n",
|
||||
"encoder(asd)[\"feature_map\"].shape\n",
|
||||
"\n",
|
||||
"# Freeze all parameters of the encoder\n",
|
||||
"for param in encoder.parameters():\n",
|
||||
" param.requires_grad = True\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:04<00:00, 2.05it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 1/10\n",
|
||||
"Loss: 1.0042\n",
|
||||
"Pixel Accuracy: 0.8131\n",
|
||||
"Mean Class Accuracy: 0.3035\n",
|
||||
"Mean IoU: 0.2119\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9479, IoU: 0.9112\n",
|
||||
"Class 1 - Acc: 0.4739, IoU: 0.3257\n",
|
||||
"Class 2 - Acc: 0.6635, IoU: 0.5876\n",
|
||||
"Class 3 - Acc: 0.3823, IoU: 0.0700\n",
|
||||
"Class 4 - Acc: 0.0076, IoU: 0.0012\n",
|
||||
"Class 5 - Acc: 0.0200, IoU: 0.0007\n",
|
||||
"Class 6 - Acc: 0.0012, IoU: 0.0005\n",
|
||||
"Class 7 - Acc: 0.0030, IoU: 0.0016\n",
|
||||
"Class 8 - Acc: 0.6669, IoU: 0.4891\n",
|
||||
"Class 9 - Acc: 0.6008, IoU: 0.2622\n",
|
||||
"Class 10 - Acc: 0.8642, IoU: 0.8031\n",
|
||||
"Class 11 - Acc: 0.0023, IoU: 0.0003\n",
|
||||
"Class 12 - Acc: 0.0005, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.5293, IoU: 0.4233\n",
|
||||
"Class 14 - Acc: 0.5114, IoU: 0.1243\n",
|
||||
"Class 15 - Acc: 0.0106, IoU: 0.0013\n",
|
||||
"Class 16 - Acc: 0.0783, IoU: 0.0213\n",
|
||||
"Class 17 - Acc: 0.0027, IoU: 0.0020\n",
|
||||
"Class 18 - Acc: 0.0001, IoU: 0.0001\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:02<00:00, 2.06it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 2/10\n",
|
||||
"Loss: 0.5420\n",
|
||||
"Pixel Accuracy: 0.8578\n",
|
||||
"Mean Class Accuracy: 0.3853\n",
|
||||
"Mean IoU: 0.2714\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9575, IoU: 0.9333\n",
|
||||
"Class 1 - Acc: 0.6296, IoU: 0.4457\n",
|
||||
"Class 2 - Acc: 0.7383, IoU: 0.6627\n",
|
||||
"Class 3 - Acc: 0.5219, IoU: 0.2756\n",
|
||||
"Class 4 - Acc: 0.0285, IoU: 0.0000\n",
|
||||
"Class 5 - Acc: 0.5090, IoU: 0.0061\n",
|
||||
"Class 6 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 7 - Acc: 0.0207, IoU: 0.0028\n",
|
||||
"Class 8 - Acc: 0.7244, IoU: 0.5757\n",
|
||||
"Class 9 - Acc: 0.6691, IoU: 0.4452\n",
|
||||
"Class 10 - Acc: 0.8985, IoU: 0.8491\n",
|
||||
"Class 11 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.7140, IoU: 0.5871\n",
|
||||
"Class 14 - Acc: 0.5454, IoU: 0.3653\n",
|
||||
"Class 15 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 16 - Acc: 0.2911, IoU: 0.0080\n",
|
||||
"Class 17 - Acc: 0.0736, IoU: 0.0001\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:09<00:00, 2.02it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 3/10\n",
|
||||
"Loss: 0.4474\n",
|
||||
"Pixel Accuracy: 0.8694\n",
|
||||
"Mean Class Accuracy: 0.4404\n",
|
||||
"Mean IoU: 0.2925\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9605, IoU: 0.9393\n",
|
||||
"Class 1 - Acc: 0.6845, IoU: 0.4889\n",
|
||||
"Class 2 - Acc: 0.7617, IoU: 0.6874\n",
|
||||
"Class 3 - Acc: 0.5666, IoU: 0.3313\n",
|
||||
"Class 4 - Acc: 0.5337, IoU: 0.0441\n",
|
||||
"Class 5 - Acc: 0.4405, IoU: 0.0674\n",
|
||||
"Class 6 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 7 - Acc: 0.0418, IoU: 0.0063\n",
|
||||
"Class 8 - Acc: 0.7413, IoU: 0.5980\n",
|
||||
"Class 9 - Acc: 0.7017, IoU: 0.4898\n",
|
||||
"Class 10 - Acc: 0.9103, IoU: 0.8633\n",
|
||||
"Class 11 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.7591, IoU: 0.6366\n",
|
||||
"Class 14 - Acc: 0.5490, IoU: 0.3969\n",
|
||||
"Class 15 - Acc: 0.0137, IoU: 0.0000\n",
|
||||
"Class 16 - Acc: 0.1809, IoU: 0.0080\n",
|
||||
"Class 17 - Acc: 0.5217, IoU: 0.0001\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:21<00:00, 1.94it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 4/10\n",
|
||||
"Loss: 0.3998\n",
|
||||
"Pixel Accuracy: 0.8774\n",
|
||||
"Mean Class Accuracy: 0.4623\n",
|
||||
"Mean IoU: 0.3114\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9630, IoU: 0.9424\n",
|
||||
"Class 1 - Acc: 0.6986, IoU: 0.5060\n",
|
||||
"Class 2 - Acc: 0.7858, IoU: 0.7109\n",
|
||||
"Class 3 - Acc: 0.6149, IoU: 0.3757\n",
|
||||
"Class 4 - Acc: 0.5206, IoU: 0.1590\n",
|
||||
"Class 5 - Acc: 0.4578, IoU: 0.1030\n",
|
||||
"Class 6 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 7 - Acc: 0.0656, IoU: 0.0054\n",
|
||||
"Class 8 - Acc: 0.7545, IoU: 0.6145\n",
|
||||
"Class 9 - Acc: 0.7230, IoU: 0.5221\n",
|
||||
"Class 10 - Acc: 0.9161, IoU: 0.8715\n",
|
||||
"Class 11 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.7828, IoU: 0.6665\n",
|
||||
"Class 14 - Acc: 0.5560, IoU: 0.4277\n",
|
||||
"Class 15 - Acc: 0.7750, IoU: 0.0002\n",
|
||||
"Class 16 - Acc: 0.1692, IoU: 0.0113\n",
|
||||
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:18<00:00, 1.97it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 5/10\n",
|
||||
"Loss: 0.3691\n",
|
||||
"Pixel Accuracy: 0.8840\n",
|
||||
"Mean Class Accuracy: 0.5277\n",
|
||||
"Mean IoU: 0.3295\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9657, IoU: 0.9464\n",
|
||||
"Class 1 - Acc: 0.7189, IoU: 0.5318\n",
|
||||
"Class 2 - Acc: 0.7985, IoU: 0.7254\n",
|
||||
"Class 3 - Acc: 0.6484, IoU: 0.4077\n",
|
||||
"Class 4 - Acc: 0.5407, IoU: 0.2147\n",
|
||||
"Class 5 - Acc: 0.4760, IoU: 0.1247\n",
|
||||
"Class 6 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 7 - Acc: 0.1132, IoU: 0.0053\n",
|
||||
"Class 8 - Acc: 0.7630, IoU: 0.6253\n",
|
||||
"Class 9 - Acc: 0.7370, IoU: 0.5430\n",
|
||||
"Class 10 - Acc: 0.9210, IoU: 0.8773\n",
|
||||
"Class 11 - Acc: 1.0000, IoU: 0.0000\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.7987, IoU: 0.6872\n",
|
||||
"Class 14 - Acc: 0.6004, IoU: 0.4685\n",
|
||||
"Class 15 - Acc: 0.5672, IoU: 0.0035\n",
|
||||
"Class 16 - Acc: 0.3776, IoU: 0.0990\n",
|
||||
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:21<00:00, 1.95it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 6/10\n",
|
||||
"Loss: 0.3422\n",
|
||||
"Pixel Accuracy: 0.8916\n",
|
||||
"Mean Class Accuracy: 0.5649\n",
|
||||
"Mean IoU: 0.3674\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9678, IoU: 0.9492\n",
|
||||
"Class 1 - Acc: 0.7343, IoU: 0.5554\n",
|
||||
"Class 2 - Acc: 0.8114, IoU: 0.7392\n",
|
||||
"Class 3 - Acc: 0.6759, IoU: 0.4458\n",
|
||||
"Class 4 - Acc: 0.5666, IoU: 0.2546\n",
|
||||
"Class 5 - Acc: 0.4890, IoU: 0.1428\n",
|
||||
"Class 6 - Acc: 0.0053, IoU: 0.0000\n",
|
||||
"Class 7 - Acc: 0.3257, IoU: 0.0115\n",
|
||||
"Class 8 - Acc: 0.7729, IoU: 0.6382\n",
|
||||
"Class 9 - Acc: 0.7548, IoU: 0.5681\n",
|
||||
"Class 10 - Acc: 0.9247, IoU: 0.8828\n",
|
||||
"Class 11 - Acc: 0.8371, IoU: 0.0013\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.8124, IoU: 0.7043\n",
|
||||
"Class 14 - Acc: 0.7332, IoU: 0.5560\n",
|
||||
"Class 15 - Acc: 0.7383, IoU: 0.1539\n",
|
||||
"Class 16 - Acc: 0.5842, IoU: 0.3779\n",
|
||||
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:01<00:00, 2.07it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 7/10\n",
|
||||
"Loss: 0.3161\n",
|
||||
"Pixel Accuracy: 0.8988\n",
|
||||
"Mean Class Accuracy: 0.6234\n",
|
||||
"Mean IoU: 0.4084\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9703, IoU: 0.9527\n",
|
||||
"Class 1 - Acc: 0.7466, IoU: 0.5767\n",
|
||||
"Class 2 - Acc: 0.8263, IoU: 0.7558\n",
|
||||
"Class 3 - Acc: 0.7073, IoU: 0.4823\n",
|
||||
"Class 4 - Acc: 0.6031, IoU: 0.3065\n",
|
||||
"Class 5 - Acc: 0.5035, IoU: 0.1583\n",
|
||||
"Class 6 - Acc: 0.5900, IoU: 0.0231\n",
|
||||
"Class 7 - Acc: 0.7235, IoU: 0.1230\n",
|
||||
"Class 8 - Acc: 0.7802, IoU: 0.6478\n",
|
||||
"Class 9 - Acc: 0.7666, IoU: 0.5857\n",
|
||||
"Class 10 - Acc: 0.9281, IoU: 0.8875\n",
|
||||
"Class 11 - Acc: 0.6169, IoU: 0.0523\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.8327, IoU: 0.7291\n",
|
||||
"Class 14 - Acc: 0.8073, IoU: 0.6301\n",
|
||||
"Class 15 - Acc: 0.7692, IoU: 0.3528\n",
|
||||
"Class 16 - Acc: 0.6729, IoU: 0.4956\n",
|
||||
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:01<00:00, 2.07it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 8/10\n",
|
||||
"Loss: 0.2998\n",
|
||||
"Pixel Accuracy: 0.9030\n",
|
||||
"Mean Class Accuracy: 0.6303\n",
|
||||
"Mean IoU: 0.4445\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9711, IoU: 0.9539\n",
|
||||
"Class 1 - Acc: 0.7571, IoU: 0.5845\n",
|
||||
"Class 2 - Acc: 0.8360, IoU: 0.7655\n",
|
||||
"Class 3 - Acc: 0.7286, IoU: 0.5109\n",
|
||||
"Class 4 - Acc: 0.6284, IoU: 0.3354\n",
|
||||
"Class 5 - Acc: 0.5278, IoU: 0.1734\n",
|
||||
"Class 6 - Acc: 0.5810, IoU: 0.0764\n",
|
||||
"Class 7 - Acc: 0.6968, IoU: 0.2643\n",
|
||||
"Class 8 - Acc: 0.7860, IoU: 0.6541\n",
|
||||
"Class 9 - Acc: 0.7736, IoU: 0.6015\n",
|
||||
"Class 10 - Acc: 0.9303, IoU: 0.8904\n",
|
||||
"Class 11 - Acc: 0.5340, IoU: 0.1885\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.8442, IoU: 0.7447\n",
|
||||
"Class 14 - Acc: 0.8279, IoU: 0.6598\n",
|
||||
"Class 15 - Acc: 0.8079, IoU: 0.4820\n",
|
||||
"Class 16 - Acc: 0.7457, IoU: 0.5606\n",
|
||||
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:06<00:00, 2.04it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 9/10\n",
|
||||
"Loss: 0.2808\n",
|
||||
"Pixel Accuracy: 0.9083\n",
|
||||
"Mean Class Accuracy: 0.6407\n",
|
||||
"Mean IoU: 0.4691\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9735, IoU: 0.9573\n",
|
||||
"Class 1 - Acc: 0.7711, IoU: 0.6084\n",
|
||||
"Class 2 - Acc: 0.8466, IoU: 0.7781\n",
|
||||
"Class 3 - Acc: 0.7457, IoU: 0.5347\n",
|
||||
"Class 4 - Acc: 0.6598, IoU: 0.3831\n",
|
||||
"Class 5 - Acc: 0.5488, IoU: 0.1873\n",
|
||||
"Class 6 - Acc: 0.5466, IoU: 0.1055\n",
|
||||
"Class 7 - Acc: 0.6838, IoU: 0.3030\n",
|
||||
"Class 8 - Acc: 0.7932, IoU: 0.6649\n",
|
||||
"Class 9 - Acc: 0.7872, IoU: 0.6204\n",
|
||||
"Class 10 - Acc: 0.9330, IoU: 0.8935\n",
|
||||
"Class 11 - Acc: 0.5515, IoU: 0.2404\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.8499, IoU: 0.7519\n",
|
||||
"Class 14 - Acc: 0.8496, IoU: 0.6852\n",
|
||||
"Class 15 - Acc: 0.8560, IoU: 0.5703\n",
|
||||
"Class 16 - Acc: 0.7770, IoU: 0.6282\n",
|
||||
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 625/625 [05:02<00:00, 2.07it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch 10/10\n",
|
||||
"Loss: 0.2632\n",
|
||||
"Pixel Accuracy: 0.9133\n",
|
||||
"Mean Class Accuracy: 0.6571\n",
|
||||
"Mean IoU: 0.4930\n",
|
||||
"\n",
|
||||
"Per-class metrics:\n",
|
||||
"Class 0 - Acc: 0.9747, IoU: 0.9590\n",
|
||||
"Class 1 - Acc: 0.7830, IoU: 0.6237\n",
|
||||
"Class 2 - Acc: 0.8578, IoU: 0.7914\n",
|
||||
"Class 3 - Acc: 0.7660, IoU: 0.5668\n",
|
||||
"Class 4 - Acc: 0.6851, IoU: 0.4189\n",
|
||||
"Class 5 - Acc: 0.5689, IoU: 0.2065\n",
|
||||
"Class 6 - Acc: 0.5763, IoU: 0.1325\n",
|
||||
"Class 7 - Acc: 0.7205, IoU: 0.3465\n",
|
||||
"Class 8 - Acc: 0.7998, IoU: 0.6733\n",
|
||||
"Class 9 - Acc: 0.7946, IoU: 0.6361\n",
|
||||
"Class 10 - Acc: 0.9361, IoU: 0.8980\n",
|
||||
"Class 11 - Acc: 0.5662, IoU: 0.2605\n",
|
||||
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 13 - Acc: 0.8638, IoU: 0.7721\n",
|
||||
"Class 14 - Acc: 0.8702, IoU: 0.7258\n",
|
||||
"Class 15 - Acc: 0.8836, IoU: 0.6627\n",
|
||||
"Class 16 - Acc: 0.8378, IoU: 0.6929\n",
|
||||
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# First, let's create a simple decoder network\n",
|
||||
"import numpy as np\n",
|
||||
"import tqdm as tqdm\n",
|
||||
"class SegmentationDecoder(torch.nn.Module):\n",
|
||||
" def __init__(self, in_channels=2048, num_classes=19):\n",
|
||||
" super().__init__()\n",
|
||||
" self.decoder = torch.nn.Sequential(\n",
|
||||
" # 16x32 -> 32x64\n",
|
||||
" torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),\n",
|
||||
" torch.nn.BatchNorm2d(1024),\n",
|
||||
" torch.nn.ReLU(),\n",
|
||||
" \n",
|
||||
" # 32x64 -> 64x128\n",
|
||||
" torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),\n",
|
||||
" torch.nn.BatchNorm2d(512),\n",
|
||||
" torch.nn.ReLU(),\n",
|
||||
" \n",
|
||||
" # 64x128 -> 128x256\n",
|
||||
" torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),\n",
|
||||
" torch.nn.BatchNorm2d(256),\n",
|
||||
" torch.nn.ReLU(),\n",
|
||||
" \n",
|
||||
" # 128x256 -> 256x512\n",
|
||||
" torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),\n",
|
||||
" torch.nn.BatchNorm2d(128),\n",
|
||||
" torch.nn.ReLU(),\n",
|
||||
" \n",
|
||||
" # 256x512 -> 512x1024\n",
|
||||
" torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n",
|
||||
" torch.nn.BatchNorm2d(64),\n",
|
||||
" torch.nn.ReLU(),\n",
|
||||
" \n",
|
||||
" # Final 1x1 conv to get to num_classes\n",
|
||||
" torch.nn.Conv2d(64, num_classes, kernel_size=1)\n",
|
||||
" )\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.decoder(x)\n",
|
||||
" # Ensure exact output size\n",
|
||||
" if x.shape[-2:] != (512, 1024):\n",
|
||||
" x = torch.nn.functional.interpolate(\n",
|
||||
" x, size=(512, 1024), \n",
|
||||
" mode='bilinear', \n",
|
||||
" align_corners=False\n",
|
||||
" )\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"# Initialize decoder, optimizer, and loss function\n",
|
||||
"decoder = SegmentationDecoder().to(device)\n",
|
||||
"optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)\n",
|
||||
"criterion = torch.nn.CrossEntropyLoss(ignore_index=255)\n",
|
||||
"\n",
|
||||
"def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:\n",
|
||||
" k = (b >= 0) & (b < n)\n",
|
||||
" return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)\n",
|
||||
"\n",
|
||||
"def per_class_iou(hist: np.ndarray) -> np.ndarray:\n",
|
||||
" epsilon = 1e-5\n",
|
||||
" return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)\n",
|
||||
"\n",
|
||||
"def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):\n",
|
||||
" decoder.train()\n",
|
||||
" encoder.train() # Keep DINO frozen\n",
|
||||
" \n",
|
||||
" total_loss = 0\n",
|
||||
" hist = np.zeros((num_classes, num_classes)) # Single histogram for entire epoch\n",
|
||||
" total_pixels = 0\n",
|
||||
" correct_pixels = 0\n",
|
||||
" \n",
|
||||
" for images, labels in tqdm.tqdm(dataloader):\n",
|
||||
" images = images.to(device)\n",
|
||||
" labels = labels.to(device)\n",
|
||||
" \n",
|
||||
" # Get DINO features\n",
|
||||
" # with torch.no_grad():\n",
|
||||
" features = encoder(images)[\"feature_map\"]\n",
|
||||
" \n",
|
||||
" # Forward pass through decoder\n",
|
||||
" outputs = decoder(features)\n",
|
||||
" \n",
|
||||
" # Resize outputs to match label size if needed\n",
|
||||
" if outputs.shape[-2:] != labels.shape[-2:]:\n",
|
||||
" outputs = torch.nn.functional.interpolate(\n",
|
||||
" outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)\n",
|
||||
" \n",
|
||||
" # Calculate loss\n",
|
||||
" loss = criterion(outputs, labels)\n",
|
||||
" \n",
|
||||
" # Backward pass\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" \n",
|
||||
" total_loss += loss.item()\n",
|
||||
" \n",
|
||||
" # Calculate metrics\n",
|
||||
" preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)\n",
|
||||
" \n",
|
||||
" # Pixel Accuracy\n",
|
||||
" valid_mask = labels != 255 # Ignore index\n",
|
||||
" total_pixels += valid_mask.sum().item()\n",
|
||||
" correct_pixels += ((preds == labels) & valid_mask).sum().item()\n",
|
||||
" \n",
|
||||
" # IoU\n",
|
||||
" preds = preds.cpu().numpy()\n",
|
||||
" target = labels.cpu().numpy()\n",
|
||||
" hist += fast_hist(preds.flatten(), target.flatten(), num_classes)\n",
|
||||
" \n",
|
||||
" # Calculate final metrics\n",
|
||||
" pixel_acc = correct_pixels / total_pixels\n",
|
||||
" \n",
|
||||
" # Per-class accuracy (mean class accuracy)\n",
|
||||
" class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)\n",
|
||||
" mean_class_acc = np.nanmean(class_acc)\n",
|
||||
" \n",
|
||||
" # IoU metrics\n",
|
||||
" iou = per_class_iou(hist)\n",
|
||||
" mean_iou = np.nanmean(iou)\n",
|
||||
" \n",
|
||||
" metrics = {\n",
|
||||
" 'loss': total_loss / len(dataloader),\n",
|
||||
" 'pixel_acc': pixel_acc,\n",
|
||||
" 'mean_class_acc': mean_class_acc,\n",
|
||||
" 'mean_iou': mean_iou,\n",
|
||||
" 'class_iou': iou,\n",
|
||||
" 'class_acc': class_acc\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" return metrics\n",
|
||||
"\n",
|
||||
"train_loader = torch.utils.data.DataLoader(\n",
|
||||
" GTA5_dataset, \n",
|
||||
" batch_size=4,\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=4\n",
|
||||
")\n",
|
||||
"# Training loop\n",
|
||||
"num_epochs = 10\n",
|
||||
"for epoch in range(num_epochs):\n",
|
||||
" metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)\n",
|
||||
" \n",
|
||||
" print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
|
||||
" print(f\"Loss: {metrics['loss']:.4f}\")\n",
|
||||
" print(f\"Pixel Accuracy: {metrics['pixel_acc']:.4f}\")\n",
|
||||
" print(f\"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}\")\n",
|
||||
" print(f\"Mean IoU: {metrics['mean_iou']:.4f}\")\n",
|
||||
" \n",
|
||||
" # Optionally print per-class metrics\n",
|
||||
" print(\"\\nPer-class metrics:\")\n",
|
||||
" for i in range(19): # Assuming 19 classes\n",
|
||||
" print(f\"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Epoch 3/10\n",
|
||||
"# Loss: 1.2027\n",
|
||||
"# Pixel Accuracy: 0.6440\n",
|
||||
"# Mean Class Accuracy: 0.1213\n",
|
||||
"# Mean IoU: 0.0818\n",
|
||||
"\n",
|
||||
"# Per-class metrics:\n",
|
||||
"# Class 0 - Acc: 0.6736, IoU: 0.6670\n",
|
||||
"# Class 1 - Acc: 0.1000, IoU: 0.0000\n",
|
||||
"# Class 2 - Acc: 0.3823, IoU: 0.1999\n",
|
||||
"# Class 3 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 4 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 5 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 6 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 7 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 8 - Acc: 0.4519, IoU: 0.0889\n",
|
||||
"# Class 9 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 10 - Acc: 0.6970, IoU: 0.5976\n",
|
||||
"# Class 11 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 12 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 13 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 14 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 15 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 16 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 17 - Acc: 0.0000, IoU: 0.0000\n",
|
||||
"# Class 18 - Acc: 0.0000, IoU: 0.0000"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ImportError",
|
||||
"evalue": "cannot import name 'get_id_to_color' from 'datasets' (unknown location)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_id_to_color \n\u001b[1;32m 5\u001b[0m img_idx \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m310\u001b[39m\n\u001b[1;32m 7\u001b[0m embeddings \u001b[38;5;241m=\u001b[39m encoder\u001b[38;5;241m.\u001b[39mget_intermediate_layers(GTA5_dataset[img_idx][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device), n\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, reshape\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, return_class_token\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, norm\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m]\n",
|
||||
"\u001b[0;31mImportError\u001b[0m: cannot import name 'get_id_to_color' from 'datasets' (unknown location)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from datasets import get_id_to_color \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"img_idx = 310\n",
|
||||
"\n",
|
||||
"embeddings = encoder.get_intermediate_layers(GTA5_dataset[img_idx][0].unsqueeze(0).to(device), n=1, reshape=True, return_class_token=False, norm=False)[0]\n",
|
||||
"out = decoder(embeddings)\n",
|
||||
"id_to_color = get_id_to_color()\n",
|
||||
"\n",
|
||||
"pred = out.argmax(1).cpu().numpy()\n",
|
||||
"pred = pred.reshape(518, 1036)\n",
|
||||
"# Convert class IDs to RGB colors\n",
|
||||
"color_map = np.array([id_to_color.get(i, (0, 0, 0)) for i in range(max(id_to_color.keys()) + 1)])\n",
|
||||
"pred_rgb = color_map[pred]\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(10, 10))\n",
|
||||
"plt.imshow(pred_rgb)\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(10, 10))\n",
|
||||
"labels = GTA5_dataset[img_idx][1].cpu().numpy()\n",
|
||||
"color_map = np.array([id_to_color.get(i, (0, 0, 0)) for i in range(max(id_to_color.keys()) + 1)])\n",
|
||||
"pred_rgb = np.zeros((*labels.shape, 3), dtype=np.uint8)\n",
|
||||
"mask = labels < len(color_map)\n",
|
||||
"pred_rgb[mask] = color_map[labels[mask]]\n",
|
||||
"plt.imshow(pred_rgb)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([518, 1036])"
|
||||
]
|
||||
},
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"GTA5_dataset[img_idx][1].shape\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "dinov2",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.20"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
File diff suppressed because one or more lines are too long
|
@ -1,70 +0,0 @@
|
|||
INFO:root:Epoch 1
|
||||
INFO:root:Train Loss: 1.5444
|
||||
INFO:root:Train Similarity: 0.3211
|
||||
INFO:root:Test Loss: 1.4806
|
||||
INFO:root:Test Similarity: 0.3588
|
||||
INFO:root:Epoch 2
|
||||
INFO:root:Train Loss: 1.4716
|
||||
INFO:root:Train Similarity: 0.3646
|
||||
INFO:root:Test Loss: 1.4506
|
||||
INFO:root:Test Similarity: 0.3767
|
||||
INFO:root:Epoch 3
|
||||
INFO:root:Train Loss: 1.4467
|
||||
INFO:root:Train Similarity: 0.3790
|
||||
INFO:root:Test Loss: 1.4322
|
||||
INFO:root:Test Similarity: 0.3869
|
||||
INFO:root:Epoch 4
|
||||
INFO:root:Train Loss: 1.4311
|
||||
INFO:root:Train Similarity: 0.3878
|
||||
INFO:root:Test Loss: 1.4236
|
||||
INFO:root:Test Similarity: 0.3920
|
||||
INFO:root:Epoch 5
|
||||
INFO:root:Train Loss: 1.4195
|
||||
INFO:root:Train Similarity: 0.3941
|
||||
INFO:root:Test Loss: 1.4170
|
||||
INFO:root:Test Similarity: 0.3956
|
||||
INFO:root:Epoch 6
|
||||
INFO:root:Train Loss: 1.4102
|
||||
INFO:root:Train Similarity: 0.3993
|
||||
INFO:root:Test Loss: 1.4017
|
||||
INFO:root:Test Similarity: 0.4038
|
||||
INFO:root:Epoch 7
|
||||
INFO:root:Train Loss: 1.4021
|
||||
INFO:root:Train Similarity: 0.4037
|
||||
INFO:root:Test Loss: 1.3943
|
||||
INFO:root:Test Similarity: 0.4079
|
||||
INFO:root:Epoch 8
|
||||
INFO:root:Train Loss: 1.3956
|
||||
INFO:root:Train Similarity: 0.4072
|
||||
INFO:root:Test Loss: 1.3929
|
||||
INFO:root:Test Similarity: 0.4091
|
||||
INFO:root:Epoch 9
|
||||
INFO:root:Train Loss: 1.3894
|
||||
INFO:root:Train Similarity: 0.4106
|
||||
INFO:root:Test Loss: 1.3853
|
||||
INFO:root:Test Similarity: 0.4128
|
||||
INFO:root:Epoch 10
|
||||
INFO:root:Train Loss: 1.3839
|
||||
INFO:root:Train Similarity: 0.4136
|
||||
INFO:root:Test Loss: 1.3811
|
||||
INFO:root:Test Similarity: 0.4154
|
||||
INFO:root:Epoch 11
|
||||
INFO:root:Train Loss: 1.3789
|
||||
INFO:root:Train Similarity: 0.4162
|
||||
INFO:root:Test Loss: 1.3791
|
||||
INFO:root:Test Similarity: 0.4163
|
||||
INFO:root:Epoch 12
|
||||
INFO:root:Train Loss: 1.3740
|
||||
INFO:root:Train Similarity: 0.4189
|
||||
INFO:root:Test Loss: 1.3732
|
||||
INFO:root:Test Similarity: 0.4193
|
||||
INFO:root:Epoch 13
|
||||
INFO:root:Train Loss: 1.3725
|
||||
INFO:root:Train Similarity: 0.4197
|
||||
INFO:root:Test Loss: 1.3592
|
||||
INFO:root:Test Similarity: 0.4268
|
||||
INFO:root:Epoch 14
|
||||
INFO:root:Train Loss: 1.3692
|
||||
INFO:root:Train Similarity: 0.4215
|
||||
INFO:root:Test Loss: 1.3625
|
||||
INFO:root:Test Similarity: 0.4264
|
|
@ -1,40 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def match_vit_features(feature_map, patch_size, height, width):
|
||||
"""
|
||||
Match the feature map to the ViT patch grid.
|
||||
|
||||
Args:
|
||||
feature_map (torch.Tensor) (B, C, H, W): The feature map to match.
|
||||
patch_size (int): The patch size of the ViT model.
|
||||
height (int): The height of the image.
|
||||
width (int): The width of the image.
|
||||
|
||||
Returns:
|
||||
patches: torch.Tensor (B, N, C): The feature map matched to the ViT patch grid.
|
||||
"""
|
||||
B, C, H, W = feature_map.shape
|
||||
H_patch = height // patch_size
|
||||
W_patch = width // patch_size
|
||||
|
||||
# Interpolate to match ViT patch grid
|
||||
feature_map = nn.functional.interpolate(
|
||||
feature_map,
|
||||
size=(H_patch, W_patch),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Reshape into patches and project to match ViT dimension
|
||||
patches = feature_map.permute(0, 2, 3, 1) # [B, H, W, C]
|
||||
patches = patches.reshape(B, H_patch * W_patch, C) # [B, N, C]
|
||||
|
||||
|
||||
return patches
|
||||
|
||||
def match_feature_map(feature_map, patch_size, height, width):
|
||||
B, C, H, W = feature_map.shape
|
||||
H_patch = height // patch_size
|
||||
W_patch = width // patch_size
|
||||
# Note: unlike match_vit_features this does not resample; the reshape only succeeds
# when H * W equals H_patch * W_patch.
return feature_map.reshape(B, C, H_patch, W_patch)
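
# Usage sketch (illustrative shapes, not taken from the configs or logs above): for a
# 512x1024 image and a ViT patch size of 14, the teacher grid is
# (512 // 14) x (1024 // 14) = 36 x 73, so a res5 map of shape (B, 2048, 16, 32) is
# resampled and flattened to (B, 36 * 73, 2048) = (B, 2628, 2048) before any channel
# projection to the teacher dimension.
example_map = torch.randn(2, 2048, 16, 32)  # assumed res5 shape for a 512x1024 input
example_patches = match_vit_features(example_map, patch_size=14, height=512, width=1024)
print(example_patches.shape)  # torch.Size([2, 2628, 2048])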
|