Remove distillation folder

pull/493/head
arda 2024-12-01 03:30:00 +00:00
parent 33172fba61
commit c38640be55
21 changed files with 0 additions and 4850 deletions

View File

@@ -1,45 +0,0 @@
student:
  model_name: resnet
teacher:
  model_name: dinov2_vitg14
  out_dim: 1536
data_transform:
  n_global_crops: 2
  n_local_crops: 8
  global_crops_scale: [0.32, 1.0]
  local_crops_scale: [0.05, 0.32]
  global_crops_size: [224, 224]
  local_crops_size: [224, 224]
data_loader:
  batch_size: 32
  num_workers: 8
  shuffle: true
  collate_fn: collate_data_and_cast
optimizer:
  type: AdamW
  lr: 0.0005
dino_loss:
  weight: 1
  student_temp: 0.1
  teacher_temp: 0.07
train:
  max_epochs: 100
  save_checkpoint_freq: 10
  resume_checkpoint: null
model_wrapper:
  model_type: resnet
  n_patches: 256
  target_feature:
    - res5
  feature_matcher_config:
    out_channels: 1536
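
This YAML is consumed as a plain dictionary; a minimal loading sketch (the training notebook further down does the same, and the relative path is an assumption):

import yaml

# Minimal sketch: load the config above into a dict (path assumed relative to the
# distillation folder, matching the notebook's "config/config.yaml").
with open("config/config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

print(cfg["teacher"]["model_name"])            # dinov2_vitg14
print(cfg["data_transform"]["n_local_crops"])  # 8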

View File

@@ -1,31 +0,0 @@
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from typing import Optional
from torchvision import transforms


class GTA5Dataset(Dataset):
    def __init__(self, img_dir: str = "/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/GTA5/GTA5/images", transform: Optional[transforms.Compose] = None):
        self.img_dir = img_dir
        self.transform = transform
        # Get all image files
        self.images = []
        for img_name in os.listdir(self.img_dir):
            if img_name.endswith(('.jpg', '.png')):
                self.images.append(os.path.join(self.img_dir, img_name))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
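
A minimal usage sketch for the dataset above; the transform values and batch size are assumptions, and the hard-coded default image directory is expected to exist:

# Minimal usage sketch: wrap the dataset in a DataLoader with a standard
# torchvision transform (the resize/batch values here are assumptions).
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = GTA5Dataset(transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
images = next(iter(loader))  # tensor of shape (8, 3, 224, 224)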

View File

@@ -1,13 +0,0 @@
import torch


def collate_data_and_cast(samples_list):
    n_global_crops = len(samples_list[0]["global_crops"])
    n_local_crops = len(samples_list[0]["local_crops"])
    collated_global_crops = torch.stack([s["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
    collated_local_crops = torch.stack([s["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
    return {
        "collated_global_crops": collated_global_crops,
        "collated_local_crops": collated_local_crops,
    }
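
The collate function assumes each sample is a dict holding lists of crop tensors, as produced by DataAugmentationDINO; a sketch of the expected shapes, with crop counts taken from the config above and crop sizes assumed:

# Sketch of the input/output contract (crop counts from config.yaml; crop sizes assumed).
samples = [
    {"global_crops": [torch.randn(3, 224, 224) for _ in range(2)],
     "local_crops": [torch.randn(3, 224, 224) for _ in range(8)]}
    for _ in range(4)  # a batch of 4 samples
]
batch = collate_data_and_cast(samples)
print(batch["collated_global_crops"].shape)  # torch.Size([8, 3, 224, 224])  -> 2 crops x 4 samples
print(batch["collated_local_crops"].shape)   # torch.Size([32, 3, 224, 224]) -> 8 crops x 4 samples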

View File

@@ -1,40 +0,0 @@
import os
import torch
from torch.utils.data import Dataset, Subset
from PIL import Image
from typing import Optional
from torchvision import transforms
from datasets import load_dataset
import datasets
import numpy as np
import random


class ImageNetDataset(Dataset):
    def __init__(self, type='train', transform: Optional[transforms.Compose] = None, num_samples: Optional[int] = None):
        self.dataset = load_dataset('imagenet-1k', trust_remote_code=True)
        if type == 'train':
            self.dataset = self.dataset['train']
        elif type == 'validation':
            self.dataset = self.dataset['validation']
        else:
            self.dataset = self.dataset['test']
        if num_samples is not None:
            # Randomly sample indices
            indices = random.sample(range(len(self.dataset)), num_samples)
            self.dataset = Subset(self.dataset, indices)
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        if self.transform:
            image = self.dataset[idx]['image']
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return self.transform(image)
        else:
            return self.dataset[idx]['image']
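
A hedged usage sketch; access to the gated imagenet-1k dataset on the Hugging Face hub is assumed:

# Hedged sketch: pull a small random subset of the validation split with no transform,
# so __getitem__ returns raw PIL images.
val_subset = ImageNetDataset(type='validation', transform=None, num_samples=16)
print(len(val_subset))  # 16
img = val_subset[0]     # PIL.Image.Image in the dataset's native resolution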

View File

@@ -1,378 +0,0 @@
from torch.utils.data import Dataset
from typing import Tuple, List, Optional
import torch
from PIL import Image
import os
import numpy as np
import random
from albumentations import Compose
import math
import itertools
import torch.nn.functional as F
class GTA5(Dataset):
    """
    GTA5 Dataset class for loading and transforming GTA5 dataset images and labels for semantic segmentation tasks.
    """
    def __init__(self, GTA5_path: str, transform: Optional[Compose] = None, FDA: float = None):
        """
        Initializes the GTA5 dataset class.

        This constructor sets up the dataset for use, optionally applying Frequency Domain Adaptation (FDA) and other transformations to the data.

        Args:
            GTA5_path (str): The root directory path where the GTA5 dataset is stored.
            transform (callable, optional): A function/transform that takes in an image and label and returns a transformed version. Defaults to None.
            FDA (float, optional): The beta value for Frequency Domain Adaptation. If None, FDA is not applied. Defaults to None.
        """
        self.GTA5_path = GTA5_path
        self.transform = transform
        self.FDA = FDA
        self.data = self._load_data()
        self.color_to_id = get_color_to_id()
        self.target_images = self._load_target_images() if FDA else []

    def _load_data(self) -> List[Tuple[str, str]]:
        """
        Load data paths for GTA5 dataset images and labels.

        This method walks through the GTA5 directory structure, pairing each image file in the 'images' folder with the label file of the same name in the 'labels' folder.

        Returns:
            list: A list of tuples, each containing the path to an image file and the corresponding label file.
        """
        data = []
        image_dir = os.path.join(self.GTA5_path, 'images')
        label_dir = os.path.join(self.GTA5_path, 'labels')
        for image_filename in os.listdir(image_dir):
            image_path = os.path.join(image_dir, image_filename)
            label_path = os.path.join(label_dir, image_filename)
            data.append((image_path, label_path))
        return data

    def _load_target_images(self) -> List[Tuple[str, str]]:
        """
        Load target images for Frequency Domain Adaptation.

        This method walks through the Cityscapes directory structure, looking for label files in the 'gtFine' folder and pairing each with the corresponding image file.

        Returns:
            list: A list of tuples, each containing the path to a label file and the corresponding image file.
        """
        target_images = []
        city_path = self.GTA5_path.replace('GTA5', 'Cityscapes')
        city_image_dir = os.path.join(city_path, 'Cityspaces', 'gtFine', 'train')
        for root, _, files in os.walk(city_image_dir):
            for file in files:
                if 'Id' in file:
                    label_path = os.path.join(root, file)
                    image_path = label_path.replace('gtFine/', 'images/').replace('_gtFine_labelTrainIds', '_leftImg8bit')
                    target_images.append((label_path, image_path))
        return target_images

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the image and label at the specified index.

        Args:
            index (int): The index of the data point to retrieve.

        Returns:
            tuple: A tuple containing the transformed image and label.
        """
        img_path, label_path = self.data[index]
        img = Image.open(img_path).convert('RGB')
        label = self._convert_rgb_to_label(Image.open(label_path).convert('RGB'))
        img, label = np.array(img), np.array(label)
        center_padding = CenterPadding(14)
        if self.FDA:
            target_image_path = random.choice(self.target_images)[1]
            target_image = Image.open(target_image_path).convert('RGB').resize(img.shape[1::-1])
            img = FDA_transform(img, np.array(target_image), beta=self.FDA)
        if self.transform:
            transformed = self.transform(image=img, mask=label)
            img, label = transformed['image'], transformed['mask']
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255
        label = torch.from_numpy(label).long()
        return center_padding(img), center_padding(label)

    def __len__(self) -> int:
        """
        Get the number of data points in the dataset.

        Returns:
            int: The number of data points in the dataset.
        """
        return len(self.data)

    def _convert_rgb_to_label(self, img: Image.Image) -> Image.Image:
        """
        Convert an RGB color-coded label image to a single-channel label image.

        Args:
            img (Image.Image): The RGB label image to convert.

        Returns:
            Image.Image: The grayscale ('L' mode) label image, with unknown colors mapped to 255.
        """
        gray_img = Image.new('L', img.size)
        label_pixels = img.load()
        gray_pixels = gray_img.load()
        for i in range(img.width):
            for j in range(img.height):
                rgb = label_pixels[i, j]
                gray_pixels[i, j] = self.color_to_id.get(rgb, 255)
        return gray_img
from scipy.ndimage import gaussian_filter
from scipy.special import erfinv
import PIL
def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    """
    Compute a fast histogram for evaluating segmentation metrics.

    This function calculates a 2D histogram where entry (i, j) counts the number of pixels with true label i and predicted label j, using a validity mask on the predictions.

    Args:
        a (np.ndarray): An array of true labels.
        b (np.ndarray): An array of predicted labels.
        n (int): The number of different labels.

    Returns:
        np.ndarray: A 2D histogram of size (n, n).
    """
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iou(hist: np.ndarray) -> np.ndarray:
    """
    Calculate the Intersection over Union (IoU) for each class.

    The IoU for class i is hist[i, i] divided by the sum of row i and column i minus hist[i, i], with a small epsilon added to avoid division by zero.

    Args:
        hist (np.ndarray): A 2D histogram where entry (i, j) is the count of pixels with true label i and predicted label j.

    Returns:
        np.ndarray: An array containing the IoU for each class.
    """
    epsilon = 1e-5
    return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)
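

# Worked example (illustrative, not part of the original file): two classes, four pixels,
# ground truth a = [0, 0, 1, 1] and prediction b = [0, 1, 1, 1].
#   fast_hist(a, b, 2)  -> [[1, 1],
#                           [0, 2]]
#   per_class_iou(hist) -> class 0: 1 / (2 + 1 - 1) = 0.5
#                          class 1: 2 / (2 + 3 - 2) ≈ 0.667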
def poly_lr_scheduler(optimizer: torch.optim.Optimizer, init_lr: float, iter: int, lr_decay_iter: int = 1,
                      max_iter: int = 50, power: float = 0.9) -> float:
    """
    Adjusts the learning rate of the optimizer for each iteration using a polynomial decay schedule.

    The learning rate is calculated as

        lr = init_lr * (1 - iter / max_iter) ** power

    where `init_lr` is the initial learning rate, `iter` is the current iteration number, `max_iter` is the maximum number of iterations, and `power` is the exponent used in the polynomial decay.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer for which to adjust the learning rate.
        init_lr (float): The initial learning rate.
        iter (int): The current iteration number.
        lr_decay_iter (int): The iteration interval after which the learning rate is decayed. Default is 1.
        max_iter (int): The maximum number of iterations after which no more decay will happen.
        power (float): The exponent used in the polynomial decay of the learning rate.

    Returns:
        float: The updated learning rate.
    """
    # if iter % lr_decay_iter or iter > max_iter:
    #     return optimizer
    lr = init_lr * (1 - iter / max_iter) ** power
    optimizer.param_groups[0]['lr'] = lr
    return lr
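

# Worked example of the schedule above (illustrative values; init_lr matches optimizer.lr
# in config.yaml): with init_lr = 5e-4, max_iter = 50, power = 0.9,
#   iter = 0  -> lr = 5e-4
#   iter = 25 -> lr = 5e-4 * (0.5 ** 0.9) ≈ 2.68e-4
#   iter = 50 -> lr = 0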
def label_to_rgb(label: np.ndarray, height: int, width: int) -> PIL.Image:
    """
    Transforms a label matrix into a corresponding RGB image using a predefined color map.

    Each label identifier in the two-dimensional array is mapped to a specific color, which is useful for visualizing segmentation results where each label corresponds to a different segment class.

    Parameters:
        label (np.ndarray): A two-dimensional array where each element represents a label identifier.
        height (int): The desired height of the resulting RGB image.
        width (int): The desired width of the resulting RGB image.

    Returns:
        PIL.Image: An image object representing the RGB image constructed from the label matrix.
    """
    id_to_color = get_id_to_color()
    height, width = label.shape
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
    for i in range(height):
        for j in range(width):
            class_id = label[i, j]
            rgb_image[i, j] = id_to_color.get(class_id, (255, 255, 255))  # Default to white if not found
    pil_image = Image.fromarray(rgb_image, 'RGB')
    return pil_image


def generate_cow_mask(img_size: tuple, sigma: float, p: float, batch_size: int) -> np.ndarray:
    """
    Generates a batch of cow masks based on a Gaussian noise model.

    Parameters:
        img_size (tuple): The size of the images (height, width).
        sigma (float): The standard deviation of the Gaussian filter applied to the noise.
        p (float): The desired proportion of the mask that should be 'cow'.
        batch_size (int): The number of masks to generate.

    Returns:
        np.ndarray: A batch of cow masks of shape (batch_size, 1, height, width).
    """
    N = np.random.normal(size=img_size)
    Ns = gaussian_filter(N, sigma)
    t = erfinv(p * 2 - 1) * (2 ** 0.5) * Ns.std() + Ns.mean()
    masks = []
    for i in range(batch_size):
        masks.append((Ns > t).astype(float).reshape(1, *img_size))
    return np.array(masks)
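

# Illustrative call (values assumed): generate_cow_mask((224, 224), sigma=8.0, p=0.5, batch_size=4)
# returns an array of shape (4, 1, 224, 224); because the smoothed noise field Ns is sampled
# once, the batch contains copies of the same mask, with roughly a fraction p of ones.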
def get_id_to_label() -> dict:
    """
    Returns a dictionary mapping class IDs to their corresponding labels.

    Returns:
        dict: A dictionary where keys are class IDs and values are labels.
    """
    return {
        0: 'road',
        1: 'sidewalk',
        2: 'building',
        3: 'wall',
        4: 'fence',
        5: 'pole',
        6: 'light',
        7: 'sign',
        8: 'vegetation',
        9: 'terrain',
        10: 'sky',
        11: 'person',
        12: 'rider',
        13: 'car',
        14: 'truck',
        15: 'bus',
        16: 'train',
        17: 'motorcycle',
        18: 'bicycle',
        255: 'unlabeled'
    }


def get_id_to_color() -> dict:
    """
    Returns a dictionary mapping class IDs to their corresponding colors.

    Returns:
        dict: A dictionary where keys are class IDs and values are RGB color tuples.
    """
    id_to_color = {
        0: (128, 64, 128),    # road
        1: (244, 35, 232),    # sidewalk
        2: (70, 70, 70),      # building
        3: (102, 102, 156),   # wall
        4: (190, 153, 153),   # fence
        5: (153, 153, 153),   # pole
        6: (250, 170, 30),    # light
        7: (220, 220, 0),     # sign
        8: (107, 142, 35),    # vegetation
        9: (152, 251, 152),   # terrain
        10: (70, 130, 180),   # sky
        11: (220, 20, 60),    # person
        12: (255, 0, 0),      # rider
        13: (0, 0, 142),      # car
        14: (0, 0, 70),       # truck
        15: (0, 60, 100),     # bus
        16: (0, 80, 100),     # train
        17: (0, 0, 230),      # motorcycle
        18: (119, 11, 32),    # bicycle
    }
    return id_to_color


def get_color_to_id() -> dict:
    """
    Returns a dictionary mapping RGB color tuples to their corresponding class IDs.

    Returns:
        dict: A dictionary where keys are RGB color tuples and values are class IDs.
    """
    id_to_color = get_id_to_color()
    color_to_id = {color: id for id, color in id_to_color.items()}
    return color_to_id


def mix(mask, data=None, target=None):
    # Mix
    if not (data is None):
        if mask.shape[0] == data.shape[0]:
            data = torch.cat([(mask[i] * data[i] + (1 - mask[i]) * data[(i + 1) % data.shape[0]]).unsqueeze(0) for i in range(data.shape[0])])
        elif mask.shape[0] == data.shape[0] / 2:
            data = torch.cat((torch.cat([(mask[i] * data[2 * i] + (1 - mask[i]) * data[2 * i + 1]).unsqueeze(0) for i in range(int(data.shape[0] / 2))]),
                              torch.cat([((1 - mask[i]) * data[2 * i] + mask[i] * data[2 * i + 1]).unsqueeze(0) for i in range(int(data.shape[0] / 2))])))
    if not (target is None):
        target = torch.cat([(mask[i] * target[i] + (1 - mask[i]) * target[(i + 1) % target.shape[0]]).unsqueeze(0) for i in range(target.shape[0])])
    return data, target


def generate_class_mask(pred, classes):
    pred, classes = torch.broadcast_tensors(pred.unsqueeze(0), classes.unsqueeze(1).unsqueeze(2))
    N = pred.eq(classes).sum(0)
    return N


class CenterPadding(torch.nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        output = F.pad(x, pads)
        return output
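
CenterPadding is what lets convolutional students consume the same inputs as the patch-14 ViT teacher; a minimal sketch of its effect on a batched tensor (shapes assumed):

# Minimal sketch (shapes assumed): pad H and W of a batched image tensor up to the
# next multiple of the ViT patch size (14), splitting the padding evenly per side.
pad_to_patch = CenterPadding(14)
x = torch.randn(1, 3, 512, 1024)
print(pad_to_patch(x).shape)  # torch.Size([1, 3, 518, 1036])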

View File

@@ -1,3 +0,0 @@
from .helper import get_teacher_and_student
from .backbones import CustomResNet, DINOv2ViT
from .resnet import ResNet

View File

@@ -1,64 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureMixerLayer(nn.Module):
    def __init__(self, in_dim, mlp_ratio=1):
        super().__init__()
        self.mix = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, int(in_dim * mlp_ratio)),
            nn.ReLU(),
            nn.Linear(int(in_dim * mlp_ratio), in_dim),
        )
        for m in self.modules():
            if isinstance(m, (nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        return x + self.mix(x)


class MixVPR(nn.Module):
    def __init__(self,
                 in_channels=1024,
                 in_h=14,
                 in_w=14,
                 out_channels=512,
                 mix_depth=4,
                 mlp_ratio=1,
                 out_rows=4,
                 ) -> None:
        super().__init__()
        self.in_h = in_h                  # height of input feature maps
        self.in_w = in_w                  # width of input feature maps
        self.in_channels = in_channels    # depth of input feature maps
        self.out_channels = out_channels  # depth-wise projection dimension
        self.out_rows = out_rows          # row-wise projection dimension
        self.mix_depth = mix_depth        # L, the number of stacked FeatureMixers
        self.mlp_ratio = mlp_ratio        # ratio of the mid projection layer in the mixer block
        hw = in_h * in_w
        self.mix = nn.Sequential(*[
            FeatureMixerLayer(in_dim=hw, mlp_ratio=mlp_ratio)
            for _ in range(self.mix_depth)
        ])
        self.channel_proj = nn.Linear(in_channels, out_channels)
        self.row_proj = nn.Linear(hw, out_rows)

    def forward(self, x):
        x = x.flatten(2)
        x = self.mix(x)
        x = x.permute(0, 2, 1)
        x = self.channel_proj(x)
        x = x.permute(0, 2, 1)
        x = self.row_proj(x)
        x = F.normalize(x.flatten(1), p=2, dim=1)
        return x
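
MixVPR is only referenced from commented-out code in the backbones below, but its shape contract is worth recording; a minimal sketch with the default arguments (batch size assumed):

# Minimal sketch (batch size assumed): the default MixVPR maps a (B, 1024, 14, 14)
# feature map to an L2-normalized descriptor of size out_channels * out_rows = 2048.
agg = MixVPR()                        # in_channels=1024, in_h=in_w=14, out_channels=512, out_rows=4
feats = torch.randn(2, 1024, 14, 14)
desc = agg(feats)
print(desc.shape)                     # torch.Size([2, 2048])
print(desc.norm(dim=1))               # ≈ tensor([1., 1.]) (unit-normalized)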

View File

@@ -1,131 +0,0 @@
from utils import match_vit_features
import torch
from torch import nn
import torchvision.models as models
import math
from .aggregators import MixVPR


class CustomResNet(nn.Module):
    def __init__(self, model_name='resnet50', patch_size=14, pretrained=True):
        super().__init__()
        """
        Custom ResNet model for distillation.

        Args:
            model_name (str): The name of the ResNet model to use.
            patch_size (int): The patch size of the ViT model.
            pretrained (bool): Whether to use pretrained weights.

        Returns:
            dict: A dictionary containing:
                feature_map: torch.Tensor (B, C, H, W): The feature map after the last layer of ResNet.
                embeddings: torch.Tensor (B, D): The embeddings after global average pooling.
                patches: torch.Tensor (B, N, C): The patches which match the ViT patch grid.
        """
        if pretrained:
            base_model = getattr(models, model_name)(weights='IMAGENET1K_V1')
        else:
            base_model = getattr(models, model_name)(weights=None)
        self.patch_size = patch_size
        # Split the model into layers
        self.conv1 = base_model.conv1
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4
        self.avgpool = base_model.avgpool
        in_channels = 2048 if model_name in ['resnet50', 'resnet101', 'resnet152'] else 512
        self.feature_matcher = nn.Conv2d(in_channels, 1536, kernel_size=1)
        # self.mix = MixVPR(in_channels=int(in_channels), in_h=7, in_w=7, out_channels=512, mix_depth=4, mlp_ratio=1, out_rows=4)
        self.attention = nn.MultiheadAttention(embed_dim=1536, num_heads=16, batch_first=True)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        layer_3 = self.layer3(x)
        layer4 = self.layer4(layer_3)  # Final feature map before pooling
        feature_map = torch.nn.functional.interpolate(
            layer4,
            size=(16, 16),
            mode='bilinear',
            align_corners=False
        )
        feature_map = self.feature_matcher(feature_map)
        pooled = self.avgpool(feature_map)     # Global average pooling
        embeddings = torch.flatten(pooled, 1)  # Flatten to get embeddings
        # print(f"layer_3 shape: {layer_3.shape}")
        # contrastive_embeddings = self.mix(layer4)
        B, C, H, W = feature_map.shape
        tokens = feature_map.view(B, C, H * W).permute(0, 2, 1)             # [B, T, C]
        attn_output, attn_weights = self.attention(tokens, tokens, tokens)  # attn_weights: [B, T, T]
        return {
            'feature_map': layer4,       # Final feature map after layer4
            'embedding': embeddings,     # Embeddings after pooling
            # 'patch_embeddings': patches,
            'dinov2_feature_map': feature_map,
            # 'contrastive_embeddings': contrastive_embeddings,
            'attn_weights': attn_weights,
        }


class DINOv2ViT(nn.Module):
    def __init__(self, model_name='dinov2_vitg14'):
        super().__init__()
        """
        DINOv2 ViT model for distillation.

        Args:
            model_name (str): The name of the DINOv2 ViT model to use.

        Returns:
            dict: A dictionary containing:
                patch_embeddings: torch.Tensor (B, N, D): The patch embeddings excluding the CLS token.
                cls_embedding: torch.Tensor (B, D): The CLS token embedding.
        """
        # Load model from torch hub
        self.model = torch.hub.load('facebookresearch/dinov2', model_name)
        # Freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        # Get features from the model's last layer
        patch_embeddings, cls_token = self.model.get_intermediate_layers(x, n=1, return_class_token=True)[0]  # [B, N, D], [B, D]
        # print(f" patch_embeddings shape: {patch_embeddings.shape}")
        # print(f" cls_token shape: {cls_token.shape}")
        # Convert patch embeddings to feature map format
        B, N, D = patch_embeddings.shape
        P = int(math.sqrt(N))  # side of the patch grid (CLS token is already excluded)
        feature_map = patch_embeddings.reshape(B, P, P, D).permute(0, 3, 1, 2)  # [B, D, P, P]
        return {
            'patch_embeddings': patch_embeddings,  # Per-patch embeddings excluding CLS
            'embedding': cls_token,                # CLS token embedding
            'feature_map': feature_map,
        }
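
A minimal shape-check sketch for the two backbones; the printed shapes match the notebook output further down, and downloading the torchvision and DINOv2 weights via torch.hub is assumed to work in the environment:

# Minimal sketch: compare the teacher's and student's distillation targets on one image.
import torch
from models import DINOv2ViT, CustomResNet

x = torch.randn(1, 3, 224, 224)
teacher, student = DINOv2ViT(), CustomResNet()
t_out, s_out = teacher(x), student(x)
print(t_out["feature_map"].shape)         # torch.Size([1, 1536, 16, 16])
print(s_out["dinov2_feature_map"].shape)  # torch.Size([1, 1536, 16, 16])
print(s_out["embedding"].shape)           # torch.Size([1, 1536])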

View File

@@ -1,69 +0,0 @@
import sys
sys.path.append("../../dinov2")
from dinov2.layers.dino_head import DINOHead  # Adjust based on actual location
from .backbones import DINOv2ViT, CustomResNet
import torch.nn as nn


def get_teacher_and_student(cfg="../config/config.yaml"):
    """
    Initialize and return the teacher and student models based on the provided configuration.

    This function creates the backbone and heads for both teacher and student models as specified in the configuration. It supports different architectures for the student, such as ResNet and the DINOv2 Vision Transformer (ViT). The models are wrapped in `nn.ModuleDict` objects for easy integration and training.

    Args:
        cfg (dict): Configuration dictionary (for example, the parsed contents of ../config/config.yaml). The path-string default is only a reminder of where the YAML lives; callers must pass the parsed dictionary, since the function indexes into it directly.

    Returns:
        tuple:
            - teacher (nn.ModuleDict): A module dictionary containing the teacher's backbone and associated heads (`dino_head`, `ibot_head`).
            - student (nn.ModuleDict): A module dictionary containing the student's backbone and associated heads (`dino_head`, `ibot_head`).

    Raises:
        ValueError: If the student model name specified in the configuration is unsupported.
    """
    student = {}
    teacher = {}
    # Create teacher and student backbones
    teacher_backbone = DINOv2ViT(model_name=cfg["teacher"]["model_name"])
    teacher['backbone'] = teacher_backbone
    if cfg["student"]["model_name"].startswith('resnet'):
        student_backbone = CustomResNet(model_name=cfg["student"]["model_name"], pretrained=True)
    elif cfg["student"]["model_name"].startswith('dino'):
        student_backbone = DINOv2ViT(model_name=cfg["student"]["model_name"])
    else:
        raise ValueError(f"Unsupported student model: {cfg['student']['model_name']}")
    student['backbone'] = student_backbone
    student["dino_head"] = DINOHead(in_dim=cfg["student"]["dino_head"]["in_dim"],
                                    out_dim=cfg["student"]["dino_head"]["out_dim"],
                                    hidden_dim=cfg["student"]["dino_head"]["hidden_dim"],
                                    bottleneck_dim=cfg["student"]["dino_head"]["bottleneck_dim"])
    teacher["dino_head"] = DINOHead(in_dim=cfg["teacher"]["dino_head"]["in_dim"],
                                    out_dim=cfg["teacher"]["dino_head"]["out_dim"],
                                    hidden_dim=cfg["teacher"]["dino_head"]["hidden_dim"],
                                    bottleneck_dim=cfg["teacher"]["dino_head"]["bottleneck_dim"])
    student["ibot_head"] = DINOHead(in_dim=cfg["student"]["ibot_head"]["in_dim"],
                                    out_dim=cfg["student"]["ibot_head"]["out_dim"],
                                    hidden_dim=cfg["student"]["ibot_head"]["hidden_dim"],
                                    bottleneck_dim=cfg["student"]["ibot_head"]["bottleneck_dim"])
    teacher["ibot_head"] = DINOHead(in_dim=cfg["teacher"]["ibot_head"]["in_dim"],
                                    out_dim=cfg["teacher"]["ibot_head"]["out_dim"],
                                    hidden_dim=cfg["teacher"]["ibot_head"]["hidden_dim"],
                                    bottleneck_dim=cfg["teacher"]["ibot_head"]["bottleneck_dim"])
    student = nn.ModuleDict(student)
    teacher = nn.ModuleDict(teacher)
    return teacher, student

View File

@@ -1,34 +0,0 @@
import torch
import torch.nn as nn
from ...distillation_real.models.students.resnet import ResNet, BottleneckBlock, BasicStem, make_resnet_stages
# Import other student models when added
# from .efficientnet import EfficientNet, ...


class StudentModelFactory:
    """
    Factory class for creating student models.
    """
    @staticmethod
    def create_student(model_type: str, config: dict, device: torch.device) -> nn.Module:
        if model_type.lower() == 'resnet':
            stem = BasicStem(
                in_channels=config.get('in_channels', 3),
                out_channels=64,
                norm=config.get('norm_type', 'BN')
            )
            stages = make_resnet_stages(
                depth=config.get('model_depth', 50),
                block_class=config.get('block_class', BottleneckBlock),
                norm=config.get('norm_type', 'BN'),
                dilation=config.get('dilation', (1, 1, 1, 1))
            )
            student = ResNet(
                stem=stem,
                stages=stages,
                out_features=None,
                freeze_at=config.get('freeze_at', 0)
            ).to(device)
            return student
        # elif model_type.lower() == 'efficientnet':
        #     return EfficientNet(config).to(device)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
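
A hypothetical usage sketch; the config keys mirror the .get(...) defaults above, and the ResNet, BasicStem and make_resnet_stages implementations live in the companion resnet module (not shown here):

# Hypothetical usage sketch (values are illustrative, not from the original code).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
student = StudentModelFactory.create_student(
    'resnet',
    {'model_depth': 50, 'norm_type': 'BN', 'freeze_at': 0},
    device,
)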

View File

@@ -1,127 +0,0 @@
import torch
import torch.nn as nn
from ...distillation_real.models.students.resnet import ResNet, BasicStem, BottleneckBlock, make_resnet_stages
import numpy as np


class ModelWrapper(nn.Module):
    """
    A generalizable, parametric wrapper around a student model and its feature matcher.

    Args:
        model_type (str): Type of the model to initialize (e.g., 'resnet').
        model_depth (int, optional): Depth of the ResNet model (default: 50).
        block_class (nn.Module, optional): Block class for ResNet (default: BottleneckBlock).
        norm_type (str, optional): Normalization type (default: 'BN').
        in_channels (int, optional): Number of input channels (default: 3).
        n_patches (int, optional): Number of ViT patches to match; the res5 map is resized to a sqrt(n_patches) x sqrt(n_patches) grid (default: 256).
        feature_matcher_config (dict, optional): Configuration for the feature matcher.
            Should include 'in_channels', 'out_channels', 'kernel_size', and other relevant parameters.
        device (torch.device, optional): Device to load the model on (default: 'cpu').

    Attributes:
        student (nn.Module): The student model.
        feature_matcher (nn.Module): The feature matcher module.
    """
    def __init__(
        self,
        model_type='resnet',
        model_depth=50,
        block_class=BottleneckBlock,
        norm_type='BN',
        in_channels=3,
        n_patches=256,
        feature_matcher_config=None,
        device=torch.device('cpu')
    ):
        super(ModelWrapper, self).__init__()
        self.device = device
        self.n_patches = n_patches
        # Initialize the student model based on the specified type
        if model_type.lower() == 'resnet':
            stem = BasicStem(in_channels=in_channels, out_channels=64, norm=norm_type)
            stages = make_resnet_stages(
                depth=model_depth,
                block_class=block_class,
                norm=norm_type,
                dilation=(1, 1, 1, 1)
            )
            self.student = ResNet(
                stem=stem,
                stages=stages,
                out_features=None,
                freeze_at=0
            ).to(self.device)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        # Initialize the feature matcher if a configuration is provided
        if feature_matcher_config:
            self.feature_matcher = self._initialize_feature_matcher(feature_matcher_config)
        else:
            self.feature_matcher = None

    def _initialize_feature_matcher(self, config):
        """
        Initializes the feature matcher based on the provided configuration.

        Args:
            config (dict): Configuration dictionary for the feature matcher.

        Returns:
            nn.Module: Initialized feature matcher module.
        """
        layers = []
        in_channels = config.get('in_channels', 2048)
        out_channels = config.get('out_channels', 1536)
        kernel_size = config.get('kernel_size', 1)
        stride = config.get('stride', 1)
        padding = config.get('padding', 0)
        activation = config.get('activation', None)
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
        if activation:
            layers.append(getattr(nn, activation)())
        return nn.Sequential(*layers).to(self.device)

    def forward(self, x):
        """
        Forward pass through the student model and feature matcher.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: (matched_features, student_output)
        """
        # Forward pass through the student model
        student_output = self.student(x)
        # If a feature matcher is defined, process the res5 feature map
        if self.feature_matcher and 'res5' in student_output:
            res5_feature = student_output['res5']
            interpolated_feature = torch.nn.functional.interpolate(
                res5_feature,
                size=(int(np.sqrt(self.n_patches)), int(np.sqrt(self.n_patches))),
                mode='bilinear',
                align_corners=False
            )
            matched_features = self.feature_matcher(interpolated_feature)
            return matched_features, student_output
        else:
            return None, student_output

    def get_feature_map_shape(self, x, feature_key='res5'):
        """
        Utility method to get the shape of a specified feature map.

        Args:
            x (torch.Tensor): Input tensor.
            feature_key (str, optional): Key of the feature map in the student output (default: 'res5').

        Returns:
            torch.Size: Shape of the specified feature map.
        """
        output = self.student(x)
        if feature_key in output:
            return output[feature_key].shape
        else:
            raise KeyError(f"Feature key '{feature_key}' not found in student output.")
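
A hedged usage sketch tying the wrapper to the config above (n_patches: 256, feature_matcher out_channels: 1536); it assumes the companion ResNet returns a dict containing a 'res5' feature map, which is exactly what the forward pass checks for:

# Hedged sketch: values mirror config.yaml (n_patches: 256 -> 16x16 grid,
# feature_matcher out_channels: 1536). Assumes the student ResNet exposes 'res5'.
wrapper = ModelWrapper(
    model_type='resnet',
    n_patches=256,
    feature_matcher_config={'in_channels': 2048, 'out_channels': 1536, 'kernel_size': 1},
)
matched, raw = wrapper(torch.randn(1, 3, 224, 224))
print(matched.shape)  # torch.Size([1, 1536, 16, 16]) -- same grid and width as the DINOv2 teacher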

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,494 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/swiglu_ffn.py:43: UserWarning: xFormers is available (SwiGLU)\n",
" warnings.warn(\"xFormers is available (SwiGLU)\")\n",
"/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/attention.py:27: UserWarning: xFormers is available (Attention)\n",
" warnings.warn(\"xFormers is available (Attention)\")\n",
"/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/block.py:33: UserWarning: xFormers is available (Block)\n",
" warnings.warn(\"xFormers is available (Block)\")\n",
"Using cache found in /home/arda/.cache/torch/hub/facebookresearch_dinov2_main\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"layer_3 shape: torch.Size([1, 1024, 14, 14])\n",
"torch.Size([1, 1536, 16, 16])\n",
"torch.Size([1, 1536])\n",
"torch.Size([1, 2048])\n"
]
}
],
"source": [
"from models import DINOv2ViT, CustomResNet\n",
"import torch\n",
"device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"x = torch.randn(1, 3, 224, 224).to(device)\n",
"teacher = DINOv2ViT().to(device)\n",
"\n",
"\n",
"# out = teacher(x)\n",
"# print(out[\"patch_embeddings\"].shape)\n",
"# print(out[\"embedding\"].shape)\n",
"# print(out[\"feature_map\"].shape)\n",
"\n",
"student = CustomResNet().to(device)\n",
"out = student(x)\n",
"print(out[\"dinov2_feature_map\"].shape)\n",
"print(out[\"embedding\"].shape)\n",
"print(out[\"contrastive_embeddings\"].shape)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2d696e26ecff4b7abda04f1dc053b276",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/257 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e9d0656970474a3481a599b8ed19b4e6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/25 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5c1f7bd7cd8e422d99c4e3331582fd91",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/257 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b91727b023b24d7cb7dd098ddc80f464",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading dataset shards: 0%| | 0/25 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# from datasets.GTA5 import GTA5Dataset\n",
"import sys\n",
"sys.path.append('./datasets') # Add the datasets directory to the Python path\n",
"\n",
"from collate_fn import collate_data_and_cast # Adjusted import statement\n",
"# from datasets.collate_fn import collate_data_and_cast\n",
"from dinov2.data.augmentations import DataAugmentationDINO\n",
"from torch.utils.data import DataLoader\n",
"from imagenet import ImageNetDataset\n",
"\n",
"import yaml\n",
"\n",
"# Load configurations\n",
"with open(\"config/config.yaml\", \"r\") as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"# Data Transformation\n",
"data_transform = DataAugmentationDINO(\n",
" global_crops_scale=tuple(cfg['data_transform']['global_crops_scale']),\n",
" local_crops_scale=tuple(cfg['data_transform']['local_crops_scale']),\n",
" local_crops_number=cfg['data_transform']['n_local_crops'],\n",
" global_crops_size=tuple(cfg['data_transform']['global_crops_size']),\n",
" local_crops_size=tuple(cfg['data_transform']['local_crops_size']),\n",
")\n",
"\n",
"\n",
"# Create train and test datasets\n",
"train_dataset = ImageNetDataset(type='train', transform=data_transform, num_samples = 5000)\n",
"test_dataset = ImageNetDataset(type='test', transform=data_transform, num_samples = 500)\n",
"# Create train and test dataloaders\n",
"train_loader = DataLoader(\n",
" train_dataset,\n",
" batch_size=cfg['data_loader']['batch_size'], \n",
" num_workers=cfg['data_loader']['num_workers'],\n",
" shuffle=cfg['data_loader']['shuffle'],\n",
" collate_fn=collate_data_and_cast\n",
")\n",
"\n",
"test_loader = DataLoader(\n",
" test_dataset,\n",
" batch_size=cfg['data_loader']['batch_size'],\n",
" num_workers=cfg['data_loader']['num_workers'], \n",
" shuffle=False,\n",
" collate_fn=collate_data_and_cast\n",
")\n",
"\n",
"# Optimizer\n",
"optimizer = getattr(torch.optim, cfg['optimizer']['type'])([\n",
" {\"params\": student.parameters()},\n",
"], lr=2.5e-4)\n",
"\n",
"# Freeze teacher model\n",
"for param in teacher.parameters():\n",
" param.requires_grad = False\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/16 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/arda/miniconda3/envs/dinov2/lib/python3.9/site-packages/xformers/ops/unbind.py:46: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" storage_data_ptr = tensors[0].storage().data_ptr()\n",
"/home/arda/miniconda3/envs/dinov2/lib/python3.9/site-packages/xformers/ops/unbind.py:48: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" if x.storage().data_ptr() != storage_data_ptr:\n",
"100%|██████████| 16/16 [01:01<00:00, 3.82s/it]\n",
"100%|██████████| 2/2 [00:12<00:00, 6.46s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0\n",
"Train Loss: 5.8403\n",
"Train Feature Similarity: 0.0879\n",
"Train Embedding Similarity: 0.0714\n",
"Test Loss: 4.1575\n",
"Test Similarity: 0.0881\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 16/16 [01:03<00:00, 3.98s/it]\n",
"100%|██████████| 2/2 [00:12<00:00, 6.50s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1\n",
"Train Loss: 5.5363\n",
"Train Feature Similarity: 0.1243\n",
"Train Embedding Similarity: 0.0848\n",
"Test Loss: 3.9246\n",
"Test Similarity: 0.1102\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 16/16 [01:05<00:00, 4.12s/it]\n",
" 50%|█████ | 1/2 [00:10<00:10, 10.78s/it]"
]
}
],
"source": [
"from tqdm import tqdm\n",
"import os\n",
"import torch.nn.functional as F # Added import for functional operations\n",
"from torch.cuda.amp import GradScaler, autocast # Import for mixed precision\n",
"\n",
"best_test_similarity = 0\n",
"save_frequency = 5\n",
"checkpoint_dir = \"./checkpoints\"\n",
"scaler = GradScaler()\n",
"\n",
"def compute_feature_similarity(feat1, feat2):\n",
" # Reshape feature maps to 2D: (batch*height*width, channels)\n",
" f1 = feat1.reshape(-1, feat1.shape[-1])\n",
" f2 = feat2.reshape(-1, feat2.shape[-1])\n",
" \n",
" # Compute cosine similarity\n",
" similarity = torch.nn.functional.cosine_similarity(f1, f2, dim=1)\n",
" return similarity.mean()\n",
"\n",
"def compound_loss(mse_loss, cosine_sim_loss, alpha=1.0, beta=1.0):\n",
" \"\"\"\n",
" Combine MSE loss and Cosine Similarity loss.\n",
" \n",
" Args:\n",
" mse_loss (torch.Tensor): Mean Squared Error loss.\n",
" cosine_sim_loss (torch.Tensor): Cosine Similarity loss.\n",
" alpha (float): Weight for MSE loss.\n",
" beta (float): Weight for Cosine Similarity loss.\n",
" \n",
" Returns:\n",
" torch.Tensor: Combined loss.\n",
" \"\"\"\n",
" return alpha * mse_loss + beta * cosine_sim_loss\n",
"\n",
"checkpoint_path = os.path.join(checkpoint_dir, \"latest_checkpoint.pth\")\n",
"\n",
"if os.path.exists(checkpoint_path):\n",
" checkpoint = torch.load(checkpoint_path)\n",
" student.load_state_dict(checkpoint['student_state_dict'])\n",
" optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
" start_epoch = checkpoint['epoch'] + 1\n",
" best_test_similarity = checkpoint['best_test_similarity']\n",
" print(f\"Resuming from epoch {start_epoch}\")\n",
"\n",
"for epoch in range(1000):\n",
" epoch_loss = []\n",
" similarities = []\n",
" embedding_similarities = []\n",
" student.train()\n",
" teacher.eval()\n",
" for i, data in enumerate(tqdm(train_loader)):\n",
" global_crops = data[\"collated_global_crops\"].to(device)\n",
" local_crops = data[\"collated_local_crops\"].to(device)\n",
"\n",
" # Mixed precision training\n",
" with autocast():\n",
" # Get feature maps from teacher\n",
" with torch.no_grad():\n",
" teacher_output = teacher(global_crops)\n",
" teacher_feature_maps = teacher_output[\"feature_map\"]\n",
" teacher_embedding = teacher_output[\"embedding\"]\n",
"\n",
" # Get feature maps from student\n",
" student_output = student(global_crops)\n",
" student_feature_maps = student_output[\"dinov2_feature_map\"]\n",
" student_embedding = student_output[\"embedding\"]\n",
"\n",
" # Calculate MSE loss between feature maps\n",
" mse_loss = torch.nn.functional.mse_loss(\n",
" student_feature_maps,\n",
" teacher_feature_maps\n",
" )\n",
" mse_embedding_loss = torch.nn.functional.mse_loss(\n",
" student_embedding,\n",
" teacher_embedding\n",
" )\n",
" \n",
" # Calculate Cosine Similarity loss\n",
" student_feature_normalized = F.normalize(student_feature_maps, p=2, dim=1)\n",
" teacher_feature_normalized = F.normalize(teacher_feature_maps, p=2, dim=1)\n",
" cosine_similarity = torch.nn.functional.cosine_similarity(\n",
" student_feature_normalized, \n",
" teacher_feature_normalized, \n",
" dim=1\n",
" )\n",
" cosine_similarity_loss = 1 - cosine_similarity.mean() # Convert similarity to loss\n",
"\n",
" student_embedding_normalized = F.normalize(student_embedding, p=2, dim=1)\n",
" teacher_embedding_normalized = F.normalize(teacher_embedding, p=2, dim=1)\n",
" cosine_similarity_embedding = torch.nn.functional.cosine_similarity(\n",
" student_embedding_normalized, \n",
" teacher_embedding_normalized, \n",
" dim=1\n",
" )\n",
" cosine_similarity_embedding_loss = 1 - cosine_similarity_embedding.mean() # Convert similarity to loss\n",
"\n",
" # Combine the losses\n",
" total_loss = compound_loss(mse_loss, cosine_similarity_loss, alpha=1.0, beta=1.0)\n",
" total_embedding_loss = compound_loss(mse_embedding_loss, cosine_similarity_embedding_loss, alpha=1.0, beta=1.0)\n",
" total_loss += total_embedding_loss\n",
" scaler.scale(total_loss).backward()\n",
" scaler.step(optimizer)\n",
" scaler.update()\n",
" optimizer.zero_grad()\n",
"\n",
" # Calculate similarity for logging\n",
" similarity = compute_feature_similarity(student_feature_maps, teacher_feature_maps)\n",
" similarities.append(similarity.item())\n",
" # Calculate embedding similarity for logging\n",
" embedding_similarity = compute_feature_similarity(student_embedding, teacher_embedding)\n",
" embedding_similarities.append(embedding_similarity.item())\n",
" \n",
" epoch_loss.append(total_loss.item())\n",
"\n",
" # Evaluation on test set\n",
" student.eval()\n",
" test_losses = []\n",
" test_similarities = []\n",
" test_embedding_similarities = []\n",
" with torch.no_grad():\n",
" for i, data in enumerate(tqdm(test_loader)):\n",
" global_crops = data[\"collated_global_crops\"].to(device)\n",
" \n",
" teacher_output = teacher(global_crops)\n",
" student_output = student(global_crops)\n",
" \n",
" # Feature map losses\n",
" test_mse = torch.nn.functional.mse_loss(\n",
" student_output[\"dinov2_feature_map\"],\n",
" teacher_output[\"feature_map\"]\n",
" )\n",
" test_similarity = compute_feature_similarity(\n",
" student_output[\"dinov2_feature_map\"],\n",
" teacher_output[\"feature_map\"]\n",
" )\n",
" \n",
" # Embedding losses\n",
" test_embedding_mse = torch.nn.functional.mse_loss(\n",
" student_output[\"embedding\"],\n",
" teacher_output[\"embedding\"]\n",
" )\n",
" test_embedding_similarity = compute_feature_similarity(\n",
" student_output[\"embedding\"],\n",
" teacher_output[\"embedding\"]\n",
" )\n",
" \n",
" test_losses.append(test_mse.item() + test_embedding_mse.item())\n",
" test_similarities.append(test_similarity.item())\n",
" test_embedding_similarities.append(test_embedding_similarity.item())\n",
"\n",
" # Calculate average metrics\n",
" avg_train_loss = sum(epoch_loss)/len(epoch_loss)\n",
" avg_train_similarity = sum(similarities)/len(similarities)\n",
" avg_train_embedding_similarity = sum(embedding_similarities)/len(embedding_similarities)\n",
" avg_test_loss = sum(test_losses)/len(test_losses)\n",
" avg_test_similarity = sum(test_similarities)/len(test_similarities)\n",
" avg_test_embedding_similarity = sum(test_embedding_similarities)/len(test_embedding_similarities)\n",
"\n",
" # Print metrics\n",
" print(f\"Epoch {epoch}\")\n",
" print(f\"Train Loss: {avg_train_loss:.4f}\")\n",
" print(f\"Train Feature Similarity: {avg_train_similarity:.4f}\")\n",
" print(f\"Train Embedding Similarity: {avg_train_embedding_similarity:.4f}\")\n",
" print(f\"Test Loss: {avg_test_loss:.4f}\")\n",
" print(f\"Test Feature Similarity: {avg_test_similarity:.4f}\")\n",
" print(f\"Test Embedding Similarity: {avg_test_embedding_similarity:.4f}\")\n",
"\n",
" # Save checkpoint\n",
" if (epoch + 1) % save_frequency == 0:\n",
" checkpoint = {\n",
" 'epoch': epoch,\n",
" 'student_state_dict': student.state_dict(),\n",
" 'optimizer_state_dict': optimizer.state_dict(),\n",
" 'train_loss': avg_train_loss,\n",
" 'test_loss': avg_test_loss,\n",
" 'train_feature_similarity': avg_train_similarity,\n",
" 'train_embedding_similarity': avg_train_embedding_similarity,\n",
" 'test_similarity': avg_test_similarity,\n",
" 'best_test_similarity': best_test_similarity\n",
" }\n",
" torch.save(checkpoint, os.path.join(checkpoint_dir, f\"checkpoint_epoch_{epoch}.pth\"))\n",
" torch.save(checkpoint, os.path.join(checkpoint_dir, \"latest_checkpoint.pth\"))\n",
" \n",
" # Save best model\n",
" if avg_test_similarity > best_test_similarity:\n",
" best_test_similarity = avg_test_similarity\n",
" torch.save({\n",
" 'epoch': epoch,\n",
" 'student_state_dict': student.state_dict(),\n",
" 'test_similarity': avg_test_similarity,\n",
" 'test_embedding_similarity': avg_test_similarity\n",
" }, os.path.join(checkpoint_dir, \"best_model.pth\"))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"labels = torch.arange(32).repeat(2) # Creates [0,1,2,...,batch_size-1, 0,1,2,...,batch_size-1]\n",
"labels = labels.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
" 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3,\n",
" 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,\n",
" 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], device='cuda:1')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "dinov2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.20"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,957 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from datasets_gta5 import GTA5\n",
"import albumentations as A\n",
"import torch\n",
"\n",
"GTA5_PATH = '/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/GTA5/GTA5'\n",
"GTA5_IMAGES = os.path.join(GTA5_PATH, 'images')\n",
"GTA5_LABELS = os.path.join(GTA5_PATH, 'labels')\n",
"\n",
"device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\n",
"transform = A.Compose([\n",
" A.Resize(512, 1024)\n",
"])\n",
"GTA5_dataset = GTA5(GTA5_path=GTA5_PATH, transform=transform)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/swiglu_ffn.py:43: UserWarning: xFormers is available (SwiGLU)\n",
" warnings.warn(\"xFormers is available (SwiGLU)\")\n",
"/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/attention.py:27: UserWarning: xFormers is available (Attention)\n",
" warnings.warn(\"xFormers is available (Attention)\")\n",
"/storage/disk0/arda/dinov2/distillation/../../dinov2/dinov2/layers/block.py:33: UserWarning: xFormers is available (Block)\n",
" warnings.warn(\"xFormers is available (Block)\")\n"
]
},
{
"data": {
"text/plain": [
"CustomResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" (feature_matcher): Conv2d(2048, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from models import CustomResNet\n",
"encoder = CustomResNet()\n",
"# encoder.load_state_dict(torch.load('student_backbone_checkpoint.pth')['backbone_state_dict'])\n",
"\n",
"encoder.eval()\n",
"encoder.to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"asd = torch.randn(1, 3, 512, 1024).to(device)\n",
"encoder(asd)[\"feature_map\"].shape\n",
"\n",
"# Freeze all parameters of the encoder\n",
"for param in encoder.parameters():\n",
" param.requires_grad = True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:04<00:00, 2.05it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 1/10\n",
"Loss: 1.0042\n",
"Pixel Accuracy: 0.8131\n",
"Mean Class Accuracy: 0.3035\n",
"Mean IoU: 0.2119\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9479, IoU: 0.9112\n",
"Class 1 - Acc: 0.4739, IoU: 0.3257\n",
"Class 2 - Acc: 0.6635, IoU: 0.5876\n",
"Class 3 - Acc: 0.3823, IoU: 0.0700\n",
"Class 4 - Acc: 0.0076, IoU: 0.0012\n",
"Class 5 - Acc: 0.0200, IoU: 0.0007\n",
"Class 6 - Acc: 0.0012, IoU: 0.0005\n",
"Class 7 - Acc: 0.0030, IoU: 0.0016\n",
"Class 8 - Acc: 0.6669, IoU: 0.4891\n",
"Class 9 - Acc: 0.6008, IoU: 0.2622\n",
"Class 10 - Acc: 0.8642, IoU: 0.8031\n",
"Class 11 - Acc: 0.0023, IoU: 0.0003\n",
"Class 12 - Acc: 0.0005, IoU: 0.0000\n",
"Class 13 - Acc: 0.5293, IoU: 0.4233\n",
"Class 14 - Acc: 0.5114, IoU: 0.1243\n",
"Class 15 - Acc: 0.0106, IoU: 0.0013\n",
"Class 16 - Acc: 0.0783, IoU: 0.0213\n",
"Class 17 - Acc: 0.0027, IoU: 0.0020\n",
"Class 18 - Acc: 0.0001, IoU: 0.0001\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:02<00:00, 2.06it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 2/10\n",
"Loss: 0.5420\n",
"Pixel Accuracy: 0.8578\n",
"Mean Class Accuracy: 0.3853\n",
"Mean IoU: 0.2714\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9575, IoU: 0.9333\n",
"Class 1 - Acc: 0.6296, IoU: 0.4457\n",
"Class 2 - Acc: 0.7383, IoU: 0.6627\n",
"Class 3 - Acc: 0.5219, IoU: 0.2756\n",
"Class 4 - Acc: 0.0285, IoU: 0.0000\n",
"Class 5 - Acc: 0.5090, IoU: 0.0061\n",
"Class 6 - Acc: 0.0000, IoU: 0.0000\n",
"Class 7 - Acc: 0.0207, IoU: 0.0028\n",
"Class 8 - Acc: 0.7244, IoU: 0.5757\n",
"Class 9 - Acc: 0.6691, IoU: 0.4452\n",
"Class 10 - Acc: 0.8985, IoU: 0.8491\n",
"Class 11 - Acc: 0.0000, IoU: 0.0000\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.7140, IoU: 0.5871\n",
"Class 14 - Acc: 0.5454, IoU: 0.3653\n",
"Class 15 - Acc: 0.0000, IoU: 0.0000\n",
"Class 16 - Acc: 0.2911, IoU: 0.0080\n",
"Class 17 - Acc: 0.0736, IoU: 0.0001\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:09<00:00, 2.02it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 3/10\n",
"Loss: 0.4474\n",
"Pixel Accuracy: 0.8694\n",
"Mean Class Accuracy: 0.4404\n",
"Mean IoU: 0.2925\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9605, IoU: 0.9393\n",
"Class 1 - Acc: 0.6845, IoU: 0.4889\n",
"Class 2 - Acc: 0.7617, IoU: 0.6874\n",
"Class 3 - Acc: 0.5666, IoU: 0.3313\n",
"Class 4 - Acc: 0.5337, IoU: 0.0441\n",
"Class 5 - Acc: 0.4405, IoU: 0.0674\n",
"Class 6 - Acc: 0.0000, IoU: 0.0000\n",
"Class 7 - Acc: 0.0418, IoU: 0.0063\n",
"Class 8 - Acc: 0.7413, IoU: 0.5980\n",
"Class 9 - Acc: 0.7017, IoU: 0.4898\n",
"Class 10 - Acc: 0.9103, IoU: 0.8633\n",
"Class 11 - Acc: 0.0000, IoU: 0.0000\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.7591, IoU: 0.6366\n",
"Class 14 - Acc: 0.5490, IoU: 0.3969\n",
"Class 15 - Acc: 0.0137, IoU: 0.0000\n",
"Class 16 - Acc: 0.1809, IoU: 0.0080\n",
"Class 17 - Acc: 0.5217, IoU: 0.0001\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:21<00:00, 1.94it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 4/10\n",
"Loss: 0.3998\n",
"Pixel Accuracy: 0.8774\n",
"Mean Class Accuracy: 0.4623\n",
"Mean IoU: 0.3114\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9630, IoU: 0.9424\n",
"Class 1 - Acc: 0.6986, IoU: 0.5060\n",
"Class 2 - Acc: 0.7858, IoU: 0.7109\n",
"Class 3 - Acc: 0.6149, IoU: 0.3757\n",
"Class 4 - Acc: 0.5206, IoU: 0.1590\n",
"Class 5 - Acc: 0.4578, IoU: 0.1030\n",
"Class 6 - Acc: 0.0000, IoU: 0.0000\n",
"Class 7 - Acc: 0.0656, IoU: 0.0054\n",
"Class 8 - Acc: 0.7545, IoU: 0.6145\n",
"Class 9 - Acc: 0.7230, IoU: 0.5221\n",
"Class 10 - Acc: 0.9161, IoU: 0.8715\n",
"Class 11 - Acc: 0.0000, IoU: 0.0000\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.7828, IoU: 0.6665\n",
"Class 14 - Acc: 0.5560, IoU: 0.4277\n",
"Class 15 - Acc: 0.7750, IoU: 0.0002\n",
"Class 16 - Acc: 0.1692, IoU: 0.0113\n",
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:18<00:00, 1.97it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 5/10\n",
"Loss: 0.3691\n",
"Pixel Accuracy: 0.8840\n",
"Mean Class Accuracy: 0.5277\n",
"Mean IoU: 0.3295\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9657, IoU: 0.9464\n",
"Class 1 - Acc: 0.7189, IoU: 0.5318\n",
"Class 2 - Acc: 0.7985, IoU: 0.7254\n",
"Class 3 - Acc: 0.6484, IoU: 0.4077\n",
"Class 4 - Acc: 0.5407, IoU: 0.2147\n",
"Class 5 - Acc: 0.4760, IoU: 0.1247\n",
"Class 6 - Acc: 0.0000, IoU: 0.0000\n",
"Class 7 - Acc: 0.1132, IoU: 0.0053\n",
"Class 8 - Acc: 0.7630, IoU: 0.6253\n",
"Class 9 - Acc: 0.7370, IoU: 0.5430\n",
"Class 10 - Acc: 0.9210, IoU: 0.8773\n",
"Class 11 - Acc: 1.0000, IoU: 0.0000\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.7987, IoU: 0.6872\n",
"Class 14 - Acc: 0.6004, IoU: 0.4685\n",
"Class 15 - Acc: 0.5672, IoU: 0.0035\n",
"Class 16 - Acc: 0.3776, IoU: 0.0990\n",
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:21<00:00, 1.95it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 6/10\n",
"Loss: 0.3422\n",
"Pixel Accuracy: 0.8916\n",
"Mean Class Accuracy: 0.5649\n",
"Mean IoU: 0.3674\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9678, IoU: 0.9492\n",
"Class 1 - Acc: 0.7343, IoU: 0.5554\n",
"Class 2 - Acc: 0.8114, IoU: 0.7392\n",
"Class 3 - Acc: 0.6759, IoU: 0.4458\n",
"Class 4 - Acc: 0.5666, IoU: 0.2546\n",
"Class 5 - Acc: 0.4890, IoU: 0.1428\n",
"Class 6 - Acc: 0.0053, IoU: 0.0000\n",
"Class 7 - Acc: 0.3257, IoU: 0.0115\n",
"Class 8 - Acc: 0.7729, IoU: 0.6382\n",
"Class 9 - Acc: 0.7548, IoU: 0.5681\n",
"Class 10 - Acc: 0.9247, IoU: 0.8828\n",
"Class 11 - Acc: 0.8371, IoU: 0.0013\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.8124, IoU: 0.7043\n",
"Class 14 - Acc: 0.7332, IoU: 0.5560\n",
"Class 15 - Acc: 0.7383, IoU: 0.1539\n",
"Class 16 - Acc: 0.5842, IoU: 0.3779\n",
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:01<00:00, 2.07it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 7/10\n",
"Loss: 0.3161\n",
"Pixel Accuracy: 0.8988\n",
"Mean Class Accuracy: 0.6234\n",
"Mean IoU: 0.4084\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9703, IoU: 0.9527\n",
"Class 1 - Acc: 0.7466, IoU: 0.5767\n",
"Class 2 - Acc: 0.8263, IoU: 0.7558\n",
"Class 3 - Acc: 0.7073, IoU: 0.4823\n",
"Class 4 - Acc: 0.6031, IoU: 0.3065\n",
"Class 5 - Acc: 0.5035, IoU: 0.1583\n",
"Class 6 - Acc: 0.5900, IoU: 0.0231\n",
"Class 7 - Acc: 0.7235, IoU: 0.1230\n",
"Class 8 - Acc: 0.7802, IoU: 0.6478\n",
"Class 9 - Acc: 0.7666, IoU: 0.5857\n",
"Class 10 - Acc: 0.9281, IoU: 0.8875\n",
"Class 11 - Acc: 0.6169, IoU: 0.0523\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.8327, IoU: 0.7291\n",
"Class 14 - Acc: 0.8073, IoU: 0.6301\n",
"Class 15 - Acc: 0.7692, IoU: 0.3528\n",
"Class 16 - Acc: 0.6729, IoU: 0.4956\n",
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:01<00:00, 2.07it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 8/10\n",
"Loss: 0.2998\n",
"Pixel Accuracy: 0.9030\n",
"Mean Class Accuracy: 0.6303\n",
"Mean IoU: 0.4445\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9711, IoU: 0.9539\n",
"Class 1 - Acc: 0.7571, IoU: 0.5845\n",
"Class 2 - Acc: 0.8360, IoU: 0.7655\n",
"Class 3 - Acc: 0.7286, IoU: 0.5109\n",
"Class 4 - Acc: 0.6284, IoU: 0.3354\n",
"Class 5 - Acc: 0.5278, IoU: 0.1734\n",
"Class 6 - Acc: 0.5810, IoU: 0.0764\n",
"Class 7 - Acc: 0.6968, IoU: 0.2643\n",
"Class 8 - Acc: 0.7860, IoU: 0.6541\n",
"Class 9 - Acc: 0.7736, IoU: 0.6015\n",
"Class 10 - Acc: 0.9303, IoU: 0.8904\n",
"Class 11 - Acc: 0.5340, IoU: 0.1885\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.8442, IoU: 0.7447\n",
"Class 14 - Acc: 0.8279, IoU: 0.6598\n",
"Class 15 - Acc: 0.8079, IoU: 0.4820\n",
"Class 16 - Acc: 0.7457, IoU: 0.5606\n",
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:06<00:00, 2.04it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 9/10\n",
"Loss: 0.2808\n",
"Pixel Accuracy: 0.9083\n",
"Mean Class Accuracy: 0.6407\n",
"Mean IoU: 0.4691\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9735, IoU: 0.9573\n",
"Class 1 - Acc: 0.7711, IoU: 0.6084\n",
"Class 2 - Acc: 0.8466, IoU: 0.7781\n",
"Class 3 - Acc: 0.7457, IoU: 0.5347\n",
"Class 4 - Acc: 0.6598, IoU: 0.3831\n",
"Class 5 - Acc: 0.5488, IoU: 0.1873\n",
"Class 6 - Acc: 0.5466, IoU: 0.1055\n",
"Class 7 - Acc: 0.6838, IoU: 0.3030\n",
"Class 8 - Acc: 0.7932, IoU: 0.6649\n",
"Class 9 - Acc: 0.7872, IoU: 0.6204\n",
"Class 10 - Acc: 0.9330, IoU: 0.8935\n",
"Class 11 - Acc: 0.5515, IoU: 0.2404\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.8499, IoU: 0.7519\n",
"Class 14 - Acc: 0.8496, IoU: 0.6852\n",
"Class 15 - Acc: 0.8560, IoU: 0.5703\n",
"Class 16 - Acc: 0.7770, IoU: 0.6282\n",
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 625/625 [05:02<00:00, 2.07it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 10/10\n",
"Loss: 0.2632\n",
"Pixel Accuracy: 0.9133\n",
"Mean Class Accuracy: 0.6571\n",
"Mean IoU: 0.4930\n",
"\n",
"Per-class metrics:\n",
"Class 0 - Acc: 0.9747, IoU: 0.9590\n",
"Class 1 - Acc: 0.7830, IoU: 0.6237\n",
"Class 2 - Acc: 0.8578, IoU: 0.7914\n",
"Class 3 - Acc: 0.7660, IoU: 0.5668\n",
"Class 4 - Acc: 0.6851, IoU: 0.4189\n",
"Class 5 - Acc: 0.5689, IoU: 0.2065\n",
"Class 6 - Acc: 0.5763, IoU: 0.1325\n",
"Class 7 - Acc: 0.7205, IoU: 0.3465\n",
"Class 8 - Acc: 0.7998, IoU: 0.6733\n",
"Class 9 - Acc: 0.7946, IoU: 0.6361\n",
"Class 10 - Acc: 0.9361, IoU: 0.8980\n",
"Class 11 - Acc: 0.5662, IoU: 0.2605\n",
"Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"Class 13 - Acc: 0.8638, IoU: 0.7721\n",
"Class 14 - Acc: 0.8702, IoU: 0.7258\n",
"Class 15 - Acc: 0.8836, IoU: 0.6627\n",
"Class 16 - Acc: 0.8378, IoU: 0.6929\n",
"Class 17 - Acc: 0.0000, IoU: 0.0000\n",
"Class 18 - Acc: 0.0000, IoU: 0.0000\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# First, let's create a simple decoder network\n",
"import numpy as np\n",
"import tqdm as tqdm\n",
"class SegmentationDecoder(torch.nn.Module):\n",
" def __init__(self, in_channels=2048, num_classes=19):\n",
" super().__init__()\n",
" self.decoder = torch.nn.Sequential(\n",
" # 16x32 -> 32x64\n",
" torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),\n",
" torch.nn.BatchNorm2d(1024),\n",
" torch.nn.ReLU(),\n",
" \n",
" # 32x64 -> 64x128\n",
" torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),\n",
" torch.nn.BatchNorm2d(512),\n",
" torch.nn.ReLU(),\n",
" \n",
" # 64x128 -> 128x256\n",
" torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),\n",
" torch.nn.BatchNorm2d(256),\n",
" torch.nn.ReLU(),\n",
" \n",
" # 128x256 -> 256x512\n",
" torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),\n",
" torch.nn.BatchNorm2d(128),\n",
" torch.nn.ReLU(),\n",
" \n",
" # 256x512 -> 512x1024\n",
" torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),\n",
" torch.nn.BatchNorm2d(64),\n",
" torch.nn.ReLU(),\n",
" \n",
" # Final 1x1 conv to get to num_classes\n",
" torch.nn.Conv2d(64, num_classes, kernel_size=1)\n",
" )\n",
" def forward(self, x):\n",
" x = self.decoder(x)\n",
" # Ensure exact output size\n",
" if x.shape[-2:] != (512, 1024):\n",
" x = torch.nn.functional.interpolate(\n",
" x, size=(512, 1024), \n",
" mode='bilinear', \n",
" align_corners=False\n",
" )\n",
" return x\n",
"\n",
"# Initialize decoder, optimizer, and loss function\n",
"decoder = SegmentationDecoder().to(device)\n",
"optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)\n",
"criterion = torch.nn.CrossEntropyLoss(ignore_index=255)\n",
"\n",
"def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:\n",
" k = (b >= 0) & (b < n)\n",
" return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)\n",
"\n",
"def per_class_iou(hist: np.ndarray) -> np.ndarray:\n",
" epsilon = 1e-5\n",
" return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)\n",
"\n",
"def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):\n",
" decoder.train()\n",
" encoder.train() # Keep DINO frozen\n",
" \n",
" total_loss = 0\n",
" hist = np.zeros((num_classes, num_classes)) # Single histogram for entire epoch\n",
" total_pixels = 0\n",
" correct_pixels = 0\n",
" \n",
" for images, labels in tqdm.tqdm(dataloader):\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
" \n",
" # Get DINO features\n",
" # with torch.no_grad():\n",
" features = encoder(images)[\"feature_map\"]\n",
" \n",
" # Forward pass through decoder\n",
" outputs = decoder(features)\n",
" \n",
" # Resize outputs to match label size if needed\n",
" if outputs.shape[-2:] != labels.shape[-2:]:\n",
" outputs = torch.nn.functional.interpolate(\n",
" outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)\n",
" \n",
" # Calculate loss\n",
" loss = criterion(outputs, labels)\n",
" \n",
" # Backward pass\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" total_loss += loss.item()\n",
" \n",
" # Calculate metrics\n",
" preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)\n",
" \n",
" # Pixel Accuracy\n",
" valid_mask = labels != 255 # Ignore index\n",
" total_pixels += valid_mask.sum().item()\n",
" correct_pixels += ((preds == labels) & valid_mask).sum().item()\n",
" \n",
" # IoU\n",
" preds = preds.cpu().numpy()\n",
" target = labels.cpu().numpy()\n",
" hist += fast_hist(preds.flatten(), target.flatten(), num_classes)\n",
" \n",
" # Calculate final metrics\n",
" pixel_acc = correct_pixels / total_pixels\n",
" \n",
" # Per-class accuracy (mean class accuracy)\n",
" class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)\n",
" mean_class_acc = np.nanmean(class_acc)\n",
" \n",
" # IoU metrics\n",
" iou = per_class_iou(hist)\n",
" mean_iou = np.nanmean(iou)\n",
" \n",
" metrics = {\n",
" 'loss': total_loss / len(dataloader),\n",
" 'pixel_acc': pixel_acc,\n",
" 'mean_class_acc': mean_class_acc,\n",
" 'mean_iou': mean_iou,\n",
" 'class_iou': iou,\n",
" 'class_acc': class_acc\n",
" }\n",
" \n",
" return metrics\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
" GTA5_dataset, \n",
" batch_size=4,\n",
" shuffle=True,\n",
" num_workers=4\n",
")\n",
"# Training loop\n",
"num_epochs = 10\n",
"for epoch in range(num_epochs):\n",
" metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)\n",
" \n",
" print(f\"\\nEpoch {epoch+1}/{num_epochs}\")\n",
" print(f\"Loss: {metrics['loss']:.4f}\")\n",
" print(f\"Pixel Accuracy: {metrics['pixel_acc']:.4f}\")\n",
" print(f\"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}\")\n",
" print(f\"Mean IoU: {metrics['mean_iou']:.4f}\")\n",
" \n",
" # Optionally print per-class metrics\n",
" print(\"\\nPer-class metrics:\")\n",
" for i in range(19): # Assuming 19 classes\n",
" print(f\"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Epoch 3/10\n",
"# Loss: 1.2027\n",
"# Pixel Accuracy: 0.6440\n",
"# Mean Class Accuracy: 0.1213\n",
"# Mean IoU: 0.0818\n",
"\n",
"# Per-class metrics:\n",
"# Class 0 - Acc: 0.6736, IoU: 0.6670\n",
"# Class 1 - Acc: 0.1000, IoU: 0.0000\n",
"# Class 2 - Acc: 0.3823, IoU: 0.1999\n",
"# Class 3 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 4 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 5 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 6 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 7 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 8 - Acc: 0.4519, IoU: 0.0889\n",
"# Class 9 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 10 - Acc: 0.6970, IoU: 0.5976\n",
"# Class 11 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 12 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 13 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 14 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 15 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 16 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 17 - Acc: 0.0000, IoU: 0.0000\n",
"# Class 18 - Acc: 0.0000, IoU: 0.0000"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "ImportError",
"evalue": "cannot import name 'get_id_to_color' from 'datasets' (unknown location)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_id_to_color \n\u001b[1;32m 5\u001b[0m img_idx \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m310\u001b[39m\n\u001b[1;32m 7\u001b[0m embeddings \u001b[38;5;241m=\u001b[39m encoder\u001b[38;5;241m.\u001b[39mget_intermediate_layers(GTA5_dataset[img_idx][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device), n\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, reshape\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, return_class_token\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, norm\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m]\n",
"\u001b[0;31mImportError\u001b[0m: cannot import name 'get_id_to_color' from 'datasets' (unknown location)"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from datasets import get_id_to_color \n",
"\n",
"\n",
"img_idx = 310\n",
"\n",
"embeddings = encoder.get_intermediate_layers(GTA5_dataset[img_idx][0].unsqueeze(0).to(device), n=1, reshape=True, return_class_token=False, norm=False)[0]\n",
"out = decoder(embeddings)\n",
"id_to_color = get_id_to_color()\n",
"\n",
"pred = out.argmax(1).cpu().numpy()\n",
"pred = pred.reshape(518, 1036)\n",
"# Convert class IDs to RGB colors\n",
"color_map = np.array([id_to_color.get(i, (0, 0, 0)) for i in range(max(id_to_color.keys()) + 1)])\n",
"pred_rgb = color_map[pred]\n",
"\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(pred_rgb)\n",
"\n",
"plt.figure(figsize=(10, 10))\n",
"labels = GTA5_dataset[img_idx][1].cpu().numpy()\n",
"color_map = np.array([id_to_color.get(i, (0, 0, 0)) for i in range(max(id_to_color.keys()) + 1)])\n",
"pred_rgb = np.zeros((*labels.shape, 3), dtype=np.uint8)\n",
"mask = labels < len(color_map)\n",
"pred_rgb[mask] = color_map[labels[mask]]\n",
"plt.imshow(pred_rgb)\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([518, 1036])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"GTA5_dataset[img_idx][1].shape\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "dinov2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.20"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

View File

@ -1,70 +0,0 @@
INFO:root:Epoch 1
INFO:root:Train Loss: 1.5444
INFO:root:Train Similarity: 0.3211
INFO:root:Test Loss: 1.4806
INFO:root:Test Similarity: 0.3588
INFO:root:Epoch 2
INFO:root:Train Loss: 1.4716
INFO:root:Train Similarity: 0.3646
INFO:root:Test Loss: 1.4506
INFO:root:Test Similarity: 0.3767
INFO:root:Epoch 3
INFO:root:Train Loss: 1.4467
INFO:root:Train Similarity: 0.3790
INFO:root:Test Loss: 1.4322
INFO:root:Test Similarity: 0.3869
INFO:root:Epoch 4
INFO:root:Train Loss: 1.4311
INFO:root:Train Similarity: 0.3878
INFO:root:Test Loss: 1.4236
INFO:root:Test Similarity: 0.3920
INFO:root:Epoch 5
INFO:root:Train Loss: 1.4195
INFO:root:Train Similarity: 0.3941
INFO:root:Test Loss: 1.4170
INFO:root:Test Similarity: 0.3956
INFO:root:Epoch 6
INFO:root:Train Loss: 1.4102
INFO:root:Train Similarity: 0.3993
INFO:root:Test Loss: 1.4017
INFO:root:Test Similarity: 0.4038
INFO:root:Epoch 7
INFO:root:Train Loss: 1.4021
INFO:root:Train Similarity: 0.4037
INFO:root:Test Loss: 1.3943
INFO:root:Test Similarity: 0.4079
INFO:root:Epoch 8
INFO:root:Train Loss: 1.3956
INFO:root:Train Similarity: 0.4072
INFO:root:Test Loss: 1.3929
INFO:root:Test Similarity: 0.4091
INFO:root:Epoch 9
INFO:root:Train Loss: 1.3894
INFO:root:Train Similarity: 0.4106
INFO:root:Test Loss: 1.3853
INFO:root:Test Similarity: 0.4128
INFO:root:Epoch 10
INFO:root:Train Loss: 1.3839
INFO:root:Train Similarity: 0.4136
INFO:root:Test Loss: 1.3811
INFO:root:Test Similarity: 0.4154
INFO:root:Epoch 11
INFO:root:Train Loss: 1.3789
INFO:root:Train Similarity: 0.4162
INFO:root:Test Loss: 1.3791
INFO:root:Test Similarity: 0.4163
INFO:root:Epoch 12
INFO:root:Train Loss: 1.3740
INFO:root:Train Similarity: 0.4189
INFO:root:Test Loss: 1.3732
INFO:root:Test Similarity: 0.4193
INFO:root:Epoch 13
INFO:root:Train Loss: 1.3725
INFO:root:Train Similarity: 0.4197
INFO:root:Test Loss: 1.3592
INFO:root:Test Similarity: 0.4268
INFO:root:Epoch 14
INFO:root:Train Loss: 1.3692
INFO:root:Train Similarity: 0.4215
INFO:root:Test Loss: 1.3625
INFO:root:Test Similarity: 0.4264

View File

@ -1,40 +0,0 @@
import torch
import torch.nn as nn


def match_vit_features(feature_map, patch_size, height, width):
"""
Match the feature map to the ViT patch grid.
Args:
feature_map (torch.Tensor) (B, C, H, W): The feature map to match.
patch_size (int): The patch size of the ViT model.
height (int): The height of the image.
width (int): The width of the image.
Returns:
patches: torch.Tensor (B, N, C): The feature map matched to the ViT patch grid.
"""
B, C, H, W = feature_map.shape
H_patch = height // patch_size
W_patch = width // patch_size
# Interpolate to match ViT patch grid
feature_map = nn.functional.interpolate(
feature_map,
size=(H_patch, W_patch),
mode='bilinear',
align_corners=False
)
    # Flatten the spatial grid into a sequence of patch features
    patches = feature_map.permute(0, 2, 3, 1)            # [B, H_patch, W_patch, C]
    patches = patches.reshape(B, H_patch * W_patch, C)   # [B, N, C]
return patches


def match_feature_map(feature_map, patch_size, height, width):
    """Reshape a feature map onto the ViT patch grid (assumes H * W == H_patch * W_patch)."""
    B, C, H, W = feature_map.shape
    H_patch = height // patch_size
    W_patch = width // patch_size
    return feature_map.reshape(B, C, H_patch, W_patch)
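

# A minimal usage sketch (not from the original code): the feature-map shape, patch
# size, and image size below are illustrative assumptions.
if __name__ == "__main__":
    feature_map = torch.randn(2, 2048, 16, 32)                  # CNN features [B, C, H, W]
    patches = match_vit_features(feature_map, patch_size=14, height=224, width=224)
    print(patches.shape)                                        # torch.Size([2, 256, 2048]); 224 // 14 = 16
    flat = patches.permute(0, 2, 1).unsqueeze(-1)               # [B, C, N, 1]; N must equal H_patch * W_patch
    grid = match_feature_map(flat, patch_size=14, height=224, width=224)
    print(grid.shape)                                           # torch.Size([2, 2048, 16, 16])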