added input, model, merge block for raw

pull/511/head
Veronikkkka 2025-03-05 09:00:59 +00:00
parent e1277af2ba
commit 9c1bde505b
7 changed files with 1100 additions and 13 deletions

View File

@@ -0,0 +1,98 @@
# base config: dinov2/configs/train/vitl14.yaml
compute_precision:
grad_scaler: true
teacher:
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
dino_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
ibot_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
student:
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
dino_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp32
buffer_dtype: fp32
ibot_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp32
buffer_dtype: fp32
ibot:
separate_head: true
#head_n_prototypes: 131072
data_transform: "default"
train:
batch_size_per_gpu: 16 # vitg: 26+, vitl: 56, vits: 152, vitb: 120 for an 8-node setup
num_workers: 1
OFFICIAL_EPOCH_LENGTH: 100 # 1250
dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/
centering: sinkhorn_knopp
drop_path_rate: 0.4
ffn_layer: swiglufused
block_chunks: 0 # for distributed training
num_register_tokens: 0 # 0 for no register tokens
teacher:
momentum_teacher: 0.994
optim:
epochs: 20 # 500
weight_decay_end: 0.2
base_lr: 0.001 # learning rate for a batch size of 1024
warmup_epochs: 20 # 80
layerwise_decay: 1.0
evaluation:
eval_period_iterations: 1000
# adapt to model architecture
# ---------------------------
# config for vit
# "dinov2_vits14","dinov2_vitb14","dinov2_vitl14","dinov2_vitg14"
student:
arch: vit_base
patch_size: 14
crops:
global_crops_scale:
- 0.32 #0.32 default
- 1.0
local_crops_size: 98
local_crops_number: 1 # NOTE: hacky, a value of 1 means NO local crops
dino:
head_bottleneck_dim: 256 #vits: 256, vitl: 384
smooth_rank_loss_weight: 0.0 # doesn't help
# ---------------------------
# config for vim_tiny
#student:
# arch: vim_tiny
# patch_size: 16
#crops:
# local_crops_size: 96
#dino:
# head_bottleneck_dim: 256
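A minimal sketch of inspecting this config, assuming it is consumed with OmegaConf as in upstream dinov2; the file name below is hypothetical.

from omegaconf import OmegaConf

cfg = OmegaConf.load("dinov2/configs/train/custom_raw.yaml")  # hypothetical file name
print(cfg.train.dataset_path)                    # ImageNet:root=.../training_raw/
print(cfg.student.arch, cfg.student.patch_size)  # vit_base 14
print(cfg.crops.local_crops_number)              # 1, i.e. local crops disabled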

View File

@@ -5,3 +5,4 @@
from .image_net import ImageNet
from .image_net_22k import ImageNet22k
from .my_dataset import ADK20Dataset

View File

@@ -0,0 +1,117 @@
import os
from pathlib import Path
from typing import Callable, Optional, Tuple, List
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from dinov2.train.rgb_to_raw import rgb_to_raw, raw_to_rgb
class ADK20Dataset(Dataset):
def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
shuffle: bool = False,
) -> None:
"""
ADK20 (ADE20K) image dataset; targets are placeholder zeros (see __getitem__).
Args:
root (str): Path to dataset directory.
transforms (Callable, optional): Combined image and target transformations.
transform (Callable, optional): Image transformations.
target_transform (Callable, optional): Target transformations.
shuffle (bool, optional): If True, shuffles the dataset. Defaults to False.
"""
self.root = Path(root)
self.transforms = transforms
self.transform = transform
self.target_transform = target_transform
# Collect image file paths
print("root:", self.root)
self.image_paths = sorted(self.root.rglob("*.jpg")) # Adjust file format if needed
if not self.image_paths:
raise ValueError(f"No images found in dataset directory: {root}")
if shuffle:
import random
random.shuffle(self.image_paths)
self.true_len = len(self.image_paths)
print(f"Loaded {self.true_len} images from {root}")
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
"""
Loads and returns an image, target, and filepath.
Args:
index (int): Dataset index.
Returns:
Tuple[torch.Tensor, torch.Tensor, str]: (image, target, filepath)
"""
adjusted_index = index % self.true_len # Wrap the index so out-of-range values stay within the dataset
filepath = str(self.image_paths[adjusted_index])
# print("filepath:", filepath)
try:
image = Image.open(filepath).convert("RGB")
except Exception as e:
print(f"Error loading image {filepath}: {e}")
return self.__getitem__((index + 1) % self.true_len) # Skip to next valid image
if self.transform:
image = self.transform(image)
target = torch.zeros((1,)) # Modify if ADK20 has labels
if self.target_transform:
target = self.target_transform(target)
# raw_image = rgb_to_raw(filepath)
# after_raw = raw_to_rgb(raw_image)
# print("Img:", image)
# print(type(image), type(raw_image))
# print(type(image), type(after_raw), image.keys())
return image, target, filepath
# return raw_image, target, filepath
# return image, raw_image, target, filepath
def __len__(self) -> int:
return self.true_len
def rgb_to_raw(self, image_path, local_crops_number=6):
# print("Path:", image_path)
img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
if img is None:
return {}
if len(img.shape) == 3:
img_raw = img[:, :, 1]
else:
img_raw = img
if img_raw.dtype != np.uint16:
img_raw = (img_raw.astype(np.float32) / 255.0 * 65535).astype(np.uint16)
# Normalize the raw image to [0, 1]
img_raw = img_raw.astype(np.float32) / 65535.0
raw_tensor = torch.from_numpy(img_raw).unsqueeze(0) # Shape: [1, H, W]
output = {
"global_crops": [raw_tensor, raw_tensor], # Two global crops
"global_crops_teacher": [raw_tensor, raw_tensor],
"local_crops": [raw_tensor for _ in range(local_crops_number)],
"offsets": ()
}
# print("Type: ", type(rgb_to_raw))
return output
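A hedged smoke test for the dataset class above; the root path is a placeholder, and since no transform is passed the image comes back as a PIL.Image.

from dinov2.data.datasets import ADK20Dataset

ds = ADK20Dataset(root="/data/ADEChallengeData2016/images/training_raw", shuffle=True)  # placeholder path
image, target, filepath = ds[0]   # PIL image, zero target, source file path
print(len(ds), image.size, target.shape, filepath)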

View File

@@ -12,7 +12,9 @@ from torch.utils.data import Sampler
from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
from .datasets import (
ADK20Dataset
)
logger = logging.getLogger("dinov2")
@@ -49,15 +51,15 @@ def _parse_dataset_str(dataset_str: str):
for token in tokens[1:]:
key, value = token.split("=")
assert key in ("root", "extra", "split")
assert key in ("root", "extra", "split", "shuffle")
if key == "shuffle":
value = bool(int(value))
kwargs[key] = value
# if name == "HemaStandardDataset":
# class_ = HemaStandardDataset
if name == "ImageNet":
class_ = ImageNet
if "split" in kwargs:
kwargs["split"] = ImageNet.Split[kwargs["split"]]
elif name == "ImageNet22k":
class_ = ImageNet22k
class_ = ADK20Dataset
else:
raise ValueError(f'Unsupported dataset "{name}"')
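A hedged illustration of the parsing change above, calling the private helper directly (illustration only; the path is a placeholder). The new "shuffle" token is converted via bool(int(value)) and forwarded to the dataset constructor as a keyword argument.

from dinov2.data.loaders import _parse_dataset_str

class_, kwargs = _parse_dataset_str(
    "ImageNet:root=/data/ADEChallengeData2016/images/training_raw:shuffle=1"
)
print(class_, kwargs)  # kwargs == {"root": "/data/.../training_raw", "shuffle": True}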

View File

@@ -0,0 +1,642 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, List, Optional, Tuple
class BaseModule(nn.Module):
def __init__(self):
super(BaseModule, self).__init__()
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def conv3x3(in_channels, out_channels, stride=1, padding=1):
"""3x3 convolution with no bias."""
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=False)
from torch import Tensor
class Merge_block(BaseModule):
def __init__(self, fea_c, ada_c, mid_c, return_ada=True):
super(Merge_block, self).__init__()
self.fea_c = fea_c
self.ada_c = ada_c
# fc_1 input size = embedding dim + adapter channels (e.g. 768 + 16 = 784)
self.embeded_dim = 768
self.fc_1 = nn.Linear(self.embeded_dim + ada_c, mid_c)
self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
self.return_ada = return_ada
if self.return_ada:
self.conv_3 = nn.Conv1d(mid_c, ada_c * 2, kernel_size=1) # 1D Conv instead of 3x3
else:
self.conv_3 = None
def forward(self, fea, adapter, ratio=1.0):
res = fea
# print("Before concatenation: ", fea.shape, adapter.shape, self.fea_c, self.ada_c)
fea = torch.cat([fea, adapter], dim=-1) # (B, seq_len, fea_c + ada_c)
B, seq_len, C = fea.shape
fea = fea.view(B * seq_len, C)
fea = self.fc_1(fea)
fea = fea.view(B, seq_len, -1)
ada = self.fc_2(fea)
fea_out = ratio * ada + res
if self.return_ada:
ada = self.conv_3(fea.permute(0, 2, 1))
return fea_out, ada.permute(0, 2, 1)
else:
return fea_out, None
def conv7x7(in_planes: int, out_planes: int, stride: int = 3, groups: int = 1, padding: int = 3, dilation: int = 1) -> nn.Conv2d:
"""7x7 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, padding=padding, groups=groups, bias=False, dilation=dilation)
class BasicBlock(nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.LayerNorm
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer([planes]) # Modify this to pass the correct shape
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer([planes]) # Modify this to pass the correct shape
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
# x = x.to(self.expand_channels.weight.dtype)
out = self.conv1(x)
# Reshape for LayerNorm
# C, H, W = out.shape
if out.dim() == 3:
out = out.unsqueeze(0)
out = out.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
out = self.bn1(out) # Apply LayerNorm on the channel dimension (last)
out = out.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
out = self.relu(out)
out = self.conv2(out)
# if out.dim() == 3:
# out = out.unsqueeze(0)
# Reshape for LayerNorm
# C, H, W = out.shape
out = out.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
out = self.bn2(out) # Apply LayerNorm on the channel dimension (last)
out = out.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Model_level_Adapeter(BaseModule):
def __init__(self, in_c=3, in_dim=8, w_lut=True):
super(Model_level_Adapeter, self).__init__()
self.conv_1 = conv3x3(in_c, in_c, 2)
self.conv_2 = conv3x3(in_c, in_c, 2)
self.conv_3 = conv3x3(in_c, in_c, 2)
self.w_lut = w_lut
if self.w_lut: # With LUT: I1, I2, I3, I4
self.conv_4 = conv3x3(in_c, in_c, 2)
self.uni_conv = conv7x7(4*in_c, in_dim, 2, padding=3)
else: # Without LUT: I1, I2, I3
self.uni_conv = conv7x7(3*in_c, in_dim, 2, padding=3)
# self.res_1 = BasicBlock(inplanes=in_dim, planes=in_dim)
# self.res_2 = BasicBlock(inplanes=in_dim, planes=in_dim)
def forward(self, IMGS):
if self.w_lut:
adapter = torch.cat([self.conv_1(IMGS[0]), self.conv_2(IMGS[1]), self.conv_3(IMGS[2]), self.conv_4(IMGS[3])], dim=1)
else:
adapter = torch.cat([self.conv_1(IMGS[0]), self.conv_2(IMGS[1]), self.conv_3(IMGS[2])], dim=1)
adapter = self.uni_conv(adapter)
# adapter = self.res_1(adapter)
# adapter = self.res_2(adapter)
return adapter
# class Model_level_Adapeter(BaseModule):
# def __init__(self, in_c=12, in_dim=16, w_lut=True):
# super(Model_level_Adapeter, self).__init__()
# self.conv_1 = conv3x3(in_c, in_c, 2)
# self.conv_2 = conv3x3(in_c, in_c, 2)
# self.conv_3 = conv3x3(in_c, in_c, 2)
# self.w_lut = w_lut
# if self.w_lut:
# self.conv_4 = conv3x3(in_c, in_c, 2)
# self.channel_reducer = conv7x7(272, in_dim, 2, padding=3)
# self.uni_conv = conv7x7(12, in_dim, 2, padding=3)
# else:
# self.uni_conv = conv7x7(3*in_c, in_dim, 2, padding=3)
# self.res_1 = BasicBlock(inplanes=in_dim, planes=in_dim)
# self.res_2 = BasicBlock(inplanes=in_dim, planes=in_dim)
# self.expand_channels = nn.Conv2d(3, 32, kernel_size=1, stride=1, padding=0).to(torch.float16)
# def forward(self, IMGS):
# device = IMGS[0].device
# print("DEVICE", device)
# IMGS = [img.to(torch.float16) for img in IMGS]
# reduce_channels = nn.Conv2d(32, 12, kernel_size=1, bias=False).to(device).to(torch.float16)
# IMGS = [img.to(self.expand_channels.weight.dtype) for img in IMGS]
# print("Types: ", IMGS[0].dtype, self.expand_channels.weight.dtype)
# IMGS = [reduce_channels(self.expand_channels(img)) for img in IMGS]
# print(f"After first reduction: {IMGS[0].shape}")
# temp_conv_1 = nn.Conv2d(12, 12, kernel_size=3, stride=2, padding=1, bias=False).to(device).to(torch.float16)
# temp_conv_2 = nn.Conv2d(12, 12, kernel_size=3, stride=2, padding=1, bias=False).to(device).to(torch.float16)
# temp_conv_3 = nn.Conv2d(12, 12, kernel_size=3, stride=2, padding=1, bias=False).to(device).to(torch.float16)
# print("HERE HERE")
# if self.w_lut:
# temp_conv_4 = nn.Conv2d(12, 12, kernel_size=3, stride=2, padding=1, bias=False).to(device).to(torch.float16)
# print("HERE HERE 22", len(IMGS))
# adapter = torch.cat([
# temp_conv_1(IMGS[0]),
# temp_conv_2(IMGS[1]),
# temp_conv_3(IMGS[2]),
# temp_conv_4(IMGS[3])
# ], dim=1)
# else:
# adapter = torch.cat([
# temp_conv_1(IMGS[0]),
# temp_conv_2(IMGS[1]),
# temp_conv_3(IMGS[2])
# ], dim=1)
# adapter = adapter.half()
# print("HERE HERE 333")
# adapter = self.uni_conv(adapter)
# print("HERE HERE 44")
# adapter = self.res_1(adapter) # Residual Block 1
# adapter = self.res_2(adapter) # Residual Block 2
# return adapter
# class Input_level_Adapeter(nn.Module):
# def __init__(self, mode='normal', lut_dim=32, k_size=3, w_lut=True, in_channels=3):
# """
# Args:
# mode (str): Operating mode. Can be 'normal' or another mode if you extend this module.
# lut_dim (int): The output channel dimension if using the LUT branch.
# k_size (int): Kernel size for the convolutional layers.
# w_lut (bool): Whether to use the LUT branch.
# in_channels (int): Number of input channels. Typically 3 for RGB/RAW images.
# """
# super(Input_level_Adapeter, self).__init__()
# self.mode = mode
# self.lut_dim = lut_dim
# self.k_size = k_size
# self.w_lut = w_lut
# # First convolutional block.
# self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=k_size, padding=k_size // 2, bias=False)
# # self.bn1 = nn.LayerNorm(16)
# self.relu = nn.ReLU(inplace=True)
# # Second convolutional block.
# self.conv2 = nn.Conv2d(16, 32, kernel_size=k_size, padding=k_size // 2, bias=False)
# # self.bn2 = nn.LayerNorm(32)
# self.bn1 = nn.GroupNorm(1, 16)
# self.bn2 = nn.GroupNorm(1, 32)
# # If using LUT processing, map the 32-channel features to lut_dim channels.
# if self.w_lut:
# self.lut_conv = nn.Conv2d(32, lut_dim, kernel_size=1, bias=False)
# # Create two downsampling layers for multi-scale outputs.
# self.down1 = nn.Conv2d(32 if not w_lut else lut_dim,
# 32 if not w_lut else lut_dim,
# kernel_size=3, stride=2, padding=1, bias=False)
# self.down2 = nn.Conv2d(32 if not w_lut else lut_dim,
# 32 if not w_lut else lut_dim,
# kernel_size=3, stride=2, padding=1, bias=False)
# def forward(self, x):
# """
# Forward pass for the input-level adapter.
# Args:
# x (Tensor): Input image tensor of shape (B, in_channels, H, W).
# Returns:
# List[Tensor]: A list of feature maps at multiple scales. For example:
# [feat_full, feat_down1, feat_down2]
# where feat_down2 is the most downsampled feature used for adaptation.
# """
# # Initial conv block.
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# out = self.relu(out)
# # If enabled, adjust features via LUT branch.
# if self.w_lut:
# out = self.lut_conv(out)
# # Compute multi-scale features.
# feat_full = out # Original resolution feature.
# feat_down1 = self.relu(self.down1(feat_full)) # Downsampled by a factor of 2.
# feat_down2 = self.relu(self.down2(feat_down1)) # Downsampled further.
# # Return a list of features. In your transformer, you can pick the desired scale.
# return [feat_full, feat_down1, feat_down2]
class CustomLayerNorm(nn.Module):
def __init__(self, normalized_shape):
super().__init__()
self.norm = nn.LayerNorm(normalized_shape)
def forward(self, x):
# Input has shape [B, S, D] or [B, C, H, W]
if len(x.shape) == 3: # [B, S, D] or [B, C, S]
if x.shape[1] == self.norm.normalized_shape[0]:
# This is [B, C, S] format, need to transpose
x = x.transpose(1, 2) # Now [B, S, C]
x = self.norm(x)
x = x.transpose(1, 2) # Back to [B, C, S]
else:
# Already in [B, S, C] format
x = self.norm(x)
elif len(x.shape) == 4: # [B, C, H, W]
b, c, h, w = x.shape
# Reshape to [B, H*W, C]
x = x.permute(0, 2, 3, 1).reshape(b, h*w, c)
# Apply norm
x = self.norm(x)
# Reshape back to [B, C, H, W]
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
return x
# Predictor P_K
class Kernel_Predictor(nn.Module):
def __init__(self, dim, mode='low', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# Use provided scale factor or default to head_dim^-0.5
self.scale = qk_scale or head_dim ** -0.5
# Query Adaptive Learning (QAL)
self.q = nn.Parameter(torch.rand((1, 4, dim)), requires_grad=True)
# self.input_proj = nn.Conv2d(16, 3, kernel_size=1)
self.kv_downsample = nn.Sequential(
nn.Conv2d(3, dim // 8, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(1, dim // 8), # replaced BatchNorm2d with GroupNorm
nn.GELU(),
nn.Conv2d(dim // 8, dim // 4, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(1, dim // 4),
nn.GELU(),
nn.Conv2d(dim // 4, dim // 2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(1, dim // 2),
nn.GELU(),
nn.Conv2d(dim // 2, dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(1, dim),
)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.down = nn.Linear(dim, 1)
self.softmax = nn.Softmax(dim=2)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
# Set basic parameters
if mode == 'low':
self.gain_base = nn.Parameter(torch.FloatTensor([3]), requires_grad=True)
else:
self.gain_base = nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
self.r1_base = nn.Parameter(torch.FloatTensor([3]), requires_grad=False)
self.r2_base = nn.Parameter(torch.FloatTensor([2]), requires_grad=False)
def forward(self, x):
# print("X type: ", type(x), x.dim)
# x = x[0]
# x = self.input_proj(x)
# output = self.kv_downsample(x)
# print("Output type: ", type(output)) # This should print <class 'torch.Tensor'>, but it might print <class 'list'>
d_x = self.kv_downsample(x).flatten(2).transpose(1, 2) # [B, N, C]
B, N, C = d_x.shape
k = self.k(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = self.v(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
out = (attn @ v).transpose(1, 2).reshape(B, 4, C)
out = self.proj(out)
out = self.proj_drop(out)
out = self.down(out).squeeze(-1)
out = torch.unbind(out, 1)
r1, r2, gain, sigma = out[0], out[1], out[2], out[3]
r1 = 0.1 * r1 + self.r1_base
r2 = 0.1 * r2 + self.r2_base
gain = gain + self.gain_base
return r1, r2, gain, self.sigmoid(sigma)
class Matrix_Predictor(nn.Module):
def __init__(self, dim, num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# Query Adaptive Learning (QAL)
self.q = nn.Parameter(torch.rand((1, 9 + 1, dim)), requires_grad=True)
self.kv_downsample = nn.Sequential(
nn.Conv2d(3, dim // 8, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(1, dim // 8),
nn.GELU(),
nn.Conv2d(dim // 8, dim // 4, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(1, dim // 4),
nn.GELU(),
nn.Conv2d(dim // 4, dim // 2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(1, dim // 2),
nn.GELU(),
nn.Conv2d(dim // 2, dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(1, dim),
)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.down = nn.Linear(dim, 1)
self.softmax = nn.Softmax(dim=2)
self.relu = nn.ReLU()
self.ccm_base = nn.Parameter(torch.eye(3), requires_grad=False)
def forward(self, x):
d_x = self.kv_downsample(x).flatten(2).transpose(1, 2)
B, N, C = d_x.shape
k = self.k(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = self.v(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
out = (attn @ v).transpose(1, 2).reshape(B, 9 + 1, C)
out = self.proj(out)
out = self.proj_drop(out)
out = self.down(out)
out, distance = out[:, :9, :], out[:, 9:, :].squeeze(-1)
out = out.view(B, 3, 3)
ccm_matrix = 0.1 * out + self.ccm_base
distance = self.relu(distance) + 1
return ccm_matrix, distance
# AAAI 2024 NILUT; the channel count is reduced here to keep the FLOPs low
class NILUT(nn.Module):
"""
Simple residual coordinate-based neural network for fitting 3D LUTs
Official code: https://github.com/mv-lab/nilut
"""
def __init__(self, in_features=3, hidden_features=32, hidden_layers=3, out_features=3, res=True):
super().__init__()
self.res = res
self.net = []
self.net.append(nn.Linear(in_features, hidden_features))
self.net.append(nn.ReLU())
for _ in range(hidden_layers):
self.net.append(nn.Linear(hidden_features, hidden_features))
self.net.append(nn.Tanh())
self.net.append(nn.Linear(hidden_features, out_features))
if not self.res:
self.net.append(torch.nn.Sigmoid())
self.net = nn.Sequential(*self.net)
def forward(self, intensity):
output = self.net(intensity)
if self.res:
output = output + intensity
output = torch.clamp(output, 0.,1.)
return output
def _is_tensor_a_torch_image(x: Tensor) -> bool:
return x.ndim >= 2
def _assert_image_tensor(img: Tensor) -> None:
if not _is_tensor_a_torch_image(img):
raise TypeError("Tensor is not a torch image.")
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)
#print(x.device)
#print(sigma.device)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.to(req_dtype)
return img, need_cast, need_squeeze, out_dtype
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
# it is better to round before cast
img = torch.round(img)
img = img.to(out_dtype)
return img
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
_assert_image_tensor(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
img = torch_pad(img, padding, mode="reflect")
img = conv2d(img, kernel, groups=img.shape[-3])
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=3): # kernel size [9, 9] for the LOD dataset, [3, 3] for other datasets
out = []
for i in range(I1.shape[0]):
I1_gain = gain[i] * I1[i,:,:,:]
blur = gaussian_blur(I1_gain, \
[k_size, k_size], \
[r1[i], r2[i]])
sharp = blur + sigma[i] * (I1[i,:,:,:] - blur)
out.append(sharp)
return torch.stack([out[i] for i in range(I1.shape[0])], dim=0)
def SoG_algo(img, p=1):
# https://library.imaging.org/admin/apis/public/api/ist/website/downloadArticle/cic/12/1/art00008
img = img.permute(1,2,0) # (C,H,W) --> (H,W,C)
img_P = torch.pow(img, p)
R_avg = torch.mean(img_P[:,:,0]) ** (1/p)
G_avg = torch.mean(img_P[:,:,1]) ** (1/p)
B_avg = torch.mean(img_P[:,:,2]) ** (1/p)
Avg = torch.mean(img_P) ** (1/p)
R_avg = R_avg / Avg
G_avg = G_avg / Avg
B_avg = B_avg / Avg
img_out = torch.stack([img[:,:,0]/R_avg, img[:,:,1]/G_avg, img[:,:,2]/B_avg], dim=-1)
return img_out
def WB_CCM(I2, ccm_matrix, distance):
out_I3 = []
out_I4 = []
for i in range(I2.shape[0]):
# SOG White Balance Algorithm
I3 = SoG_algo(I2[i,:,:,:], distance[i])
# Camera Color Matrix
I4 = torch.tensordot(I3, ccm_matrix[i,:,:], dims=[[-1], [-1]])
I4 = torch.clamp(I4, 1e-5, 1.0)
out_I3.append(I3)
out_I4.append(I4)
return torch.stack([out_I3[i] for i in range(I2.shape[0])], dim=0), \
torch.stack([out_I4[i] for i in range(I2.shape[0])], dim=0)
class VitInputLevelAdapter(nn.Module):
def __init__(self, mode='normal', lut_dim=32, out='all', k_size=3, w_lut=True):
"""
Args:
mode (str): Operating mode, e.g. 'normal' for normal/over-exposure or 'low' for low-light.
lut_dim (int): Dimensionality for the implicit neural LUT.
k_size (int): Kernel size for the denoising operation.
w_lut (bool): Whether to use the implicit 3D Look-Up Table.
out (str): 'all' to return every intermediate image [I1, ..., I5]; any other value returns only the final image.
"""
super(VitInputLevelAdapter, self).__init__()
# These submodules predict transformation parameters from the input image.
self.Predictor_K = Kernel_Predictor(dim=64, mode=mode)
self.Predictor_M = Matrix_Predictor(dim=64)
self.w_lut = w_lut
if self.w_lut:
self.LUT = NILUT(hidden_features=lut_dim)
print("hidden_features", lut_dim)
# self.LUT = nn.Linear(224, 32)
self.k_size = k_size
self.out = out
def forward(self, I1):
# (1). I1 --> I2: Denoise & Enhancement & Sharpen
r1, r2, gain, sigma = self.Predictor_K(I1)
I2 = Gain_Denoise(I1, r1, r2, gain, sigma, k_size=self.k_size) # (B,C,H,W)
I2 = torch.clamp(I2, 1e-5, 1.0) # normal & over-exposure
ccm_matrix, distance = self.Predictor_M(I2)
# (2). I2 --> I3: White Balance, Shade of Gray
# (3). I3 --> I4: Camera Colour Matrix Transformation
I3, I4 = WB_CCM(I2, ccm_matrix, distance) # (B,H,W,C)
if self.w_lut:
# (4). I4 --> I5: Implicit Neural LUT
I5 = self.LUT(I4).permute(0,3,1,2)
if self.out == 'all': # return all features
return [I1, I2, I3.permute(0,3,1,2), I4.permute(0,3,1,2), I5]
else: # only return I5
return [I5]
else:
if self.out == 'all':
return [I1, I2, I3.permute(0,3,1,2), I4.permute(0,3,1,2)]
else:
return [I4.permute(0,3,1,2)]
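A shape-level sketch (random inputs, default dims) of how these pieces compose, matching the import path used in the DinoVisionTransformer changes below; the token alignment via slicing is a naive stand-in for the learned projection applied in the ViT.

import torch
from dinov2.models.help import VitInputLevelAdapter, Model_level_Adapeter, Merge_block

x = torch.rand(2, 3, 224, 224)                         # image-like input in [0, 1]
pre = VitInputLevelAdapter(mode='normal', lut_dim=32, w_lut=True)
I1, I2, I3, I4, I5 = pre(x)                            # each (2, 3, 224, 224)

model_adapter = Model_level_Adapeter(in_c=3, in_dim=16, w_lut=True)
ada = model_adapter([I1, I2, I3, I4])                  # (2, 16, 56, 56) after two stride-2 stages

tokens = torch.rand(2, 257, 768)                       # ViT tokens: CLS + 16x16 patches, embed dim 768
ada_seq = ada.flatten(2).transpose(1, 2)[:, :257, :]   # naive alignment to the token length
merge = Merge_block(fea_c=384, ada_c=16, mid_c=384, return_ada=True)
fused, ada_next = merge(tokens, ada_seq, ratio=1.0)    # fused: (2, 257, 768), ada_next: (2, 257, 32)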

View File

@@ -18,7 +18,8 @@ import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
from dinov2.models.help import Merge_block, Model_level_Adapeter
from dinov2.models.help import VitInputLevelAdapter as Input_level_Adapeter
logger = logging.getLogger("dinov2")
@@ -65,6 +66,17 @@ class DinoVisionTransformer(nn.Module):
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1,
# RAW adapter parameters
w_lut=True,
light_mode='normal',
lut_dim=32,
k_size=3,
merge_ratio=1.0,
model_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_model_adapter_weights.pth',
input_level_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_pre_encoder_weights.pth',
fea_c_s = [384, 768, 1920],
ada_c_s = [16, 32, 64],
mid_c_s = [384, 576, 768],
):
"""
Args:
@@ -103,6 +115,13 @@ class DinoVisionTransformer(nn.Module):
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
# RAW adapter configuration
self.w_lut = w_lut
self.light_mode = light_mode
self.lut_dim = lut_dim
self.k_size = k_size
self.merge_ratio = merge_ratio
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
@@ -166,6 +185,29 @@
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
# Initialize RAW adapter
if self.w_lut:
self.pre_encoder = Input_level_Adapeter(mode=light_mode, lut_dim=lut_dim, k_size=k_size, w_lut=w_lut)
for param in self.pre_encoder.parameters():
param.requires_grad_(True)
# self.model_adapter = Model_level_Adapeter(in_c=in_chans, w_lut=w_lut)
self.model_adapter = Model_level_Adapeter(in_c=3, in_dim=ada_c_s[0], w_lut=self.w_lut)
if model_adapter_path is not None:
print("Loading model adapter:", model_adapter_path)
adapter_state = torch.load(model_adapter_path, map_location="cpu")
self.model_adapter.load_state_dict(adapter_state, strict=False)
if input_level_adapter_path is not None:
print("Loading input-level adapter:", input_level_adapter_path)
adapter_state = torch.load(input_level_adapter_path, map_location="cpu")
self.pre_encoder.load_state_dict(adapter_state)
self.merge_1 = Merge_block(fea_c=fea_c_s[0], ada_c=ada_c_s[0], mid_c=mid_c_s[0], return_ada=True)
self.merge_2 = Merge_block(fea_c=fea_c_s[1], ada_c=ada_c_s[1], mid_c=mid_c_s[1], return_ada=True)
self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)
self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]
self.init_weights()
@@ -175,6 +217,19 @@
if self.register_tokens is not None:
nn.init.normal_(self.register_tokens, std=1e-6)
named_apply(init_weights_vit_timm, self)
if hasattr(self, "pre_encoder"):
print("Init weights for pre-encoder")
self.pre_encoder.apply(init_weights_vit_timm)
if hasattr(self, "model_adapter"):
print("Init weights for model adapter")
self.model_adapter.apply(init_weights_vit_timm)
if hasattr(self, "merge_blocks"):
print("Init weights for merge blocks")
for block in self.merge_blocks:
block.apply(init_weights_vit_timm)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
@@ -212,6 +267,15 @@
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
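# The input-level adapter returns [I1, I2, I3, I4, I5] (or [I1..I4] without the LUT); the last
# element feeds the patch embedding, while the earlier images drive the model-level adapter below.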
x_raw = self.pre_encoder(x)
if self.w_lut: # I1, I2, I3, I4
ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2], x_raw[3]])
else: # I1, I2, I3
ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2]])
x = x_raw[-1]
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
@@ -229,12 +293,34 @@
dim=1,
)
return x
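# NOTE: ada (B, C, H*W) is projected along its flattened spatial dimension to the token count
# with a Linear layer created inside this forward pass (so its weights are re-initialized on
# every call), then permuted to (B, seq_len, C) for fusion with the token sequence.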
ada = ada.reshape(ada.shape[0], ada.shape[1], -1)
batch_size, channels, features = ada.shape
target_seq_len = x.shape[1]
linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
ada = linear_proj(ada).permute(0, 2, 1)
return x, ada
def forward_features_list(self, x_list, masks_list):
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
for blk in self.blocks:
# x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
x_s = []
ada_list = []
for x, masks in zip(x_list, masks_list):
x_, ada = self.prepare_tokens_with_masks(x, masks)
x_s.append(x_)
ada_list.append(ada)
x = x_s
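# Adapter features are fused into the token stream after each of the first len(self.merge_blocks) blocks.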
for i, blk in enumerate(self.blocks):
x = blk(x)
if self.w_lut and ada is not None and i < len(self.merge_blocks):
x_ada_pairs = [self.merge_blocks[i](x_i, ada_i, ratio=self.merge_ratio) for x_i, ada_i in zip(x, ada_list)]
x, ada_list = map(list, zip(*x_ada_pairs))
all_x = x
output = []
@@ -255,10 +341,12 @@
if isinstance(x, list):
return self.forward_features_list(x, masks)
x = self.prepare_tokens_with_masks(x, masks)
x, ada = self.prepare_tokens_with_masks(x, masks)
for blk in self.blocks:
for i, blk in enumerate(self.blocks):
x = blk(x)
if self.w_lut and ada is not None and i < len(self.merge_blocks):
x, ada = self.merge_blocks[i](x, ada, ratio=self.merge_ratio)
x_norm = self.norm(x)
return {

View File

@@ -0,0 +1,139 @@
import numpy as np
import cv2
import argparse
from pathlib import Path
import torch
from PIL import Image
def rgb_to_raw(image_path="", img=None):
if img is None:
img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
else:
img = np.array(img)
if img is None:
raise ValueError("Error loading image. Please check the path.")
# print(img.shape, img.dtype)
# cv2.imwrite("new.jpg", img)
if len(img.shape) == 3:
img_raw = img[:, :, 1] # Extract green channel as a naive RAW simulation
else:
img_raw = img
if img_raw.dtype != np.uint16:
img_raw = (img_raw.astype(np.float32) / 255.0 * 65535).astype(np.uint16)
# cv2.imwrite("new2.jpg", img)
return img_raw
# def rgb_to_raw(image_path, local_crops_number=6):
# """
# Reads an image from disk, simulates a RAW image by extracting (for example)
# the green channel, and returns a dictionary formatted like the output
# of your DataAugmentationDINO pipeline.
# """
# # Read the image using OpenCV (unchanged mode)
# img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
# if img is None:
# raise ValueError("Error loading image. Please check the path.")
# # If the image has three channels, simulate RAW by taking the green channel.
# if len(img.shape) == 3:
# img_raw = img[:, :, 1] # Using the green channel as a naive RAW simulation
# else:
# img_raw = img
# # Convert to uint16 if needed.
# if img_raw.dtype != np.uint16:
# img_raw = (img_raw.astype(np.float32) / 255.0 * 65535).astype(np.uint16)
# # Normalize the raw image to [0, 1] (as float32)
# img_raw = img_raw.astype(np.float32) / 65535.0
# # Convert the raw image to a torch tensor.
# # Assuming the raw image is single channel, add a channel dimension.
# raw_tensor = torch.from_numpy(img_raw).unsqueeze(0) # Shape: [1, H, W]
# # For consistency, we simulate two global crops (for student and teacher)
# # and several local crops. Here we simply use the same raw tensor for each crop.
# output = {
# "global_crops": [raw_tensor, raw_tensor], # Two global crops
# "global_crops_teacher": [raw_tensor, raw_tensor],
# "local_crops": [raw_tensor for _ in range(local_crops_number)],
# "offsets": () # Keeping offsets empty as before
# }
# # print("Type: ", type(rgb_to_raw))
# return output
def raw_to_rgb(raw_array, pattern='RGGB', image_size=(256, 256), bits=16):
"""
Convert RAW sensor data to RGB image with improved handling of various bit depths
and white balance correction.
Args:
raw_array: NumPy array containing RAW sensor data
pattern: Bayer pattern ('RGGB', 'BGGR', 'GRBG', 'GBRG')
image_size: Tuple of (height, width); currently unused, the input array is used as-is.
bits: Bit depth of the RAW data (typically 12, 14, or 16)
Returns:
RGB image as numpy array
"""
if not isinstance(raw_array, np.ndarray):
raise TypeError("Input must be a NumPy array.")
# total_pixels = np.prod(image_size)
# if raw_array.size != total_pixels:
# raise ValueError(f"Expected raw array size {total_pixels}, but got {raw_array.size}")
# raw_image = raw_array.reshape(image_size)
max_value = 2**bits - 1
raw_image = raw_array
if raw_image.dtype != np.uint8:
raw_image = (raw_image / max_value * 255).astype(np.uint8)
bayer_patterns = {
'RGGB': cv2.COLOR_BayerBG2BGR,
'BGGR': cv2.COLOR_BayerGB2BGR,
'GRBG': cv2.COLOR_BayerGR2BGR,
'GBRG': cv2.COLOR_BayerRG2BGR
}
rgb_image = cv2.demosaicing(raw_image, bayer_patterns[pattern])
rgb_image_float = rgb_image.astype(float)
for i in range(3):
channel = rgb_image_float[:,:,i]
mean_val = np.mean(channel)
if mean_val > 0:
channel *= 128 / mean_val
rgb_image_float[:,:,i] = channel
rgb_image = np.clip(rgb_image_float, 0, 255)
# cv2.imwrite('output.jpg', rgb_image)
pil_image = Image.fromarray(rgb_image.astype(np.uint8), mode='RGB')
return pil_image
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert RGB image to RAW format.")
parser.add_argument("image_path", type=str, help="Path to the input image")
args = parser.parse_args()
raw = rgb_to_raw(args.image_path)
print(raw, type(raw), raw.shape)
# rgb = raw_to_rgb(raw)
rgb = raw_to_rgb(raw, image_size=(512, 683))
print(rgb)
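A hedged in-memory round trip of the two helpers above (random data, no files needed); an even image size is used to stay clear of Bayer-pattern edge cases.

import numpy as np
from PIL import Image
from dinov2.train.rgb_to_raw import rgb_to_raw, raw_to_rgb

pil_img = Image.fromarray(np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8))
raw = rgb_to_raw(img=pil_img)           # uint16 single-channel "RAW" (green channel)
rgb = raw_to_rgb(raw, pattern='RGGB')   # demosaiced, channel-balanced PIL image
print(raw.dtype, raw.shape, rgb.size)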