added input, model, merge block for raw
parent e1277af2ba
commit 9c1bde505b

@@ -0,0 +1,98 @@
# base config: dinov2/configs/train/vitl14.yaml
compute_precision:
  grad_scaler: true
  teacher:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
  student:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp32
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp32
        buffer_dtype: fp32

ibot:
  separate_head: true
  # head_n_prototypes: 131072
data_transform: "default"
train:
  batch_size_per_gpu: 16  # vitg: 26+, vitl: 56, vits: 152, vitb: 120 for 8 nodes
  num_workers: 1
  OFFICIAL_EPOCH_LENGTH: 100  # 1250
  dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/
  centering: sinkhorn_knopp

drop_path_rate: 0.4
ffn_layer: swiglufused
block_chunks: 0  # for distributed training
num_register_tokens: 0  # 0 for no register tokens

teacher:
  momentum_teacher: 0.994
optim:
  epochs: 20  # 500
  weight_decay_end: 0.2
  base_lr: 0.001  # learning rate for a batch size of 1024
  warmup_epochs: 20  # 80
  layerwise_decay: 1.0

evaluation:
  eval_period_iterations: 1000

# adapt to model architecture
# ---------------------------
# config for vit
# "dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"

student:
  arch: vit_base
  patch_size: 14
crops:
  global_crops_scale:
  - 0.32  # 0.32 default
  - 1.0
  local_crops_size: 98
  local_crops_number: 1  # !!! bit hacky, 1 indicates NO LOCAL CROPS !!!
dino:
  head_bottleneck_dim: 256  # vits: 256, vitl: 384
  smooth_rank_loss_weight: 0.0  # doesn't help

# ---------------------------
# config for vim_tiny
# student:
#   arch: vim_tiny
#   patch_size: 16
# crops:
#   local_crops_size: 96
# dino:
#   head_bottleneck_dim: 256
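For context, an override file like this is merged onto the DINOv2 defaults before training. A minimal sketch with OmegaConf (DINOv2's config loading is OmegaConf-based; the file name raw_config.yaml is a placeholder for the config added in this commit):

from omegaconf import OmegaConf

# Sketch only: base path comes from the comment at the top of the config above.
base = OmegaConf.load("dinov2/configs/train/vitl14.yaml")
override = OmegaConf.load("raw_config.yaml")   # hypothetical name for this file
cfg = OmegaConf.merge(base, override)
print(cfg.train.batch_size_per_gpu, cfg.optim.base_lr)  # 16 0.001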
@@ -5,3 +5,4 @@

from .image_net import ImageNet
from .image_net_22k import ImageNet22k
from .my_dataset import ADK20Dataset
@@ -0,0 +1,117 @@
import os
from pathlib import Path
from typing import Callable, Optional, Tuple, List

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from dinov2.train.rgb_to_raw import rgb_to_raw, raw_to_rgb


class ADK20Dataset(Dataset):
    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        shuffle: bool = False,
    ) -> None:
        """
        ADK20 Dataset for image classification.

        Args:
            root (str): Path to dataset directory.
            transforms (Callable, optional): Combined image and target transformations.
            transform (Callable, optional): Image transformations.
            target_transform (Callable, optional): Target transformations.
            shuffle (bool, optional): If True, shuffles the dataset. Defaults to False.
        """
        self.root = Path(root)
        self.transforms = transforms
        self.transform = transform
        self.target_transform = target_transform

        # Collect image file paths
        print("root:", self.root)
        self.image_paths = sorted(self.root.rglob("*.jpg"))  # Adjust file format if needed
        if not self.image_paths:
            raise ValueError(f"No images found in dataset directory: {root}")

        if shuffle:
            import random
            random.shuffle(self.image_paths)

        self.true_len = len(self.image_paths)
        print(f"Loaded {self.true_len} images from {root}")

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """
        Loads and returns an image, target, and filepath.

        Args:
            index (int): Dataset index.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, str]: (image, target, filepath)
        """
        adjusted_index = index % self.true_len  # Wrap the index so it always stays in range
        filepath = str(self.image_paths[adjusted_index])
        # print("filepath:", filepath)
        try:
            image = Image.open(filepath).convert("RGB")
        except Exception as e:
            print(f"Error loading image {filepath}: {e}")
            return self.__getitem__((index + 1) % self.true_len)  # Skip to next valid image

        if self.transform:
            image = self.transform(image)

        target = torch.zeros((1,))  # Modify if ADK20 has labels

        if self.target_transform:
            target = self.target_transform(target)

        # raw_image = rgb_to_raw(filepath)
        # after_raw = raw_to_rgb(raw_image)
        # print("Img:", image)
        # print(type(image), type(raw_image))
        # print(type(image), type(after_raw), image.keys())
        return image, target, filepath
        # return raw_image, target, filepath
        # return image, raw_image, target, filepath

    def __len__(self) -> int:
        return self.true_len

    def rgb_to_raw(self, image_path, local_crops_number=6):
        # print("Path:", image_path)

        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            return {}

        if len(img.shape) == 3:
            img_raw = img[:, :, 1]
        else:
            img_raw = img

        if img_raw.dtype != np.uint16:
            img_raw = (img_raw.astype(np.float32) / 255.0 * 65535).astype(np.uint16)

        # Normalize the raw image to [0, 1]
        img_raw = img_raw.astype(np.float32) / 65535.0

        raw_tensor = torch.from_numpy(img_raw).unsqueeze(0)  # Shape: [1, H, W]

        output = {
            "global_crops": [raw_tensor, raw_tensor],  # Two global crops
            "global_crops_teacher": [raw_tensor, raw_tensor],
            "local_crops": [raw_tensor for _ in range(local_crops_number)],
            "offsets": ()
        }
        # print("Type: ", type(rgb_to_raw))
        return output
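A minimal usage sketch for the dataset above (the root path and 224x224 resize are illustrative; targets are dummy zeros, so any directory of *.jpg files works):

from torchvision import transforms
from dinov2.data.datasets import ADK20Dataset  # exported via the __init__ change above

# Hypothetical root; point it at any folder containing .jpg images.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = ADK20Dataset(root="/path/to/ADEChallengeData2016/images/training_raw", transform=transform)
image, target, filepath = dataset[0]
print(image.shape, target.shape, filepath)  # torch.Size([3, 224, 224]) torch.Size([1]) .../xxx.jpg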
@@ -12,7 +12,9 @@ from torch.utils.data import Sampler

from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler

from .datasets import (
    ADK20Dataset
)

logger = logging.getLogger("dinov2")


@@ -49,15 +51,15 @@ def _parse_dataset_str(dataset_str: str):

    for token in tokens[1:]:
        key, value = token.split("=")
        assert key in ("root", "extra", "split")
        assert key in ("root", "extra", "split", "shuffle")
        if key == "shuffle":
            value = bool(int(value))
        kwargs[key] = value

    # if name == "HemaStandardDataset":
    #     class_ = HemaStandardDataset
    if name == "ImageNet":
        class_ = ImageNet
        if "split" in kwargs:
            kwargs["split"] = ImageNet.Split[kwargs["split"]]
    elif name == "ImageNet22k":
        class_ = ImageNet22k
        class_ = ADK20Dataset
    else:
        raise ValueError(f'Unsupported dataset "{name}"')
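With the shuffle key now accepted, the train.dataset_path string from the config can carry it directly. A sketch of the accepted format, mirroring the parsing above (paths are placeholders):

# "Name:key=value:key=value" -- the first token picks the dataset class, the rest become kwargs.
dataset_str = "ImageNet:root=/path/to/images/training_raw:shuffle=1"
name, *tokens = dataset_str.split(":")
kwargs = {}
for token in tokens:
    key, value = token.split("=")
    if key == "shuffle":
        value = bool(int(value))  # "1" -> True, as in the parser above
    kwargs[key] = value
print(name, kwargs)  # ImageNet {'root': '/path/to/images/training_raw', 'shuffle': True}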
@@ -0,0 +1,642 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, List, Optional, Tuple


class BaseModule(nn.Module):
    def __init__(self):
        super(BaseModule, self).__init__()


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3(in_channels, out_channels, stride=1, padding=1):
    """3x3 convolution with no bias."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=False)


from torch import Tensor


class Merge_block(BaseModule):
    def __init__(self, fea_c, ada_c, mid_c, return_ada=True):
        super(Merge_block, self).__init__()

        self.fea_c = fea_c
        self.ada_c = ada_c
        # 784 - embedded dim + adapter_c
        self.embeded_dim = 768
        self.fc_1 = nn.Linear(self.embeded_dim + ada_c, mid_c)
        self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
        self.return_ada = return_ada

        if self.return_ada:
            self.conv_3 = nn.Conv1d(mid_c, ada_c * 2, kernel_size=1)  # 1D Conv instead of 3x3
        else:
            self.conv_3 = None

    def forward(self, fea, adapter, ratio=1.0):
        res = fea
        # print("Before concatenation: ", fea.shape, adapter.shape, self.fea_c, self.ada_c)

        fea = torch.cat([fea, adapter], dim=-1)  # (B, seq_len, fea_c + ada_c)

        B, seq_len, C = fea.shape
        fea = fea.view(B * seq_len, C)
        fea = self.fc_1(fea)
        fea = fea.view(B, seq_len, -1)
        ada = self.fc_2(fea)
        fea_out = ratio * ada + res

        if self.return_ada:
            ada = self.conv_3(fea.permute(0, 2, 1))
            return fea_out, ada.permute(0, 2, 1)
        else:
            return fea_out, None


def conv7x7(in_planes: int, out_planes: int, stride: int = 3, groups: int = 1, padding: int = 3, dilation: int = 1) -> nn.Conv2d:
    """7x7 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, padding=padding, groups=groups, bias=False, dilation=dilation)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.LayerNorm
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer([planes])  # Modify this to pass the correct shape
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer([planes])  # Modify this to pass the correct shape
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        # x = x.to(self.expand_channels.weight.dtype)
        out = self.conv1(x)
        # Reshape for LayerNorm
        # C, H, W = out.shape
        if out.dim() == 3:
            out = out.unsqueeze(0)
        out = out.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        out = self.bn1(out)  # Apply LayerNorm on the channel dimension (last)
        out = out.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        out = self.relu(out)

        out = self.conv2(out)
        # if out.dim() == 3:
        #     out = out.unsqueeze(0)
        # Reshape for LayerNorm
        # C, H, W = out.shape
        out = out.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        out = self.bn2(out)  # Apply LayerNorm on the channel dimension (last)
        out = out.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Model_level_Adapeter(BaseModule):
    def __init__(self, in_c=3, in_dim=8, w_lut=True):
        super(Model_level_Adapeter, self).__init__()
        self.conv_1 = conv3x3(in_c, in_c, 2)
        self.conv_2 = conv3x3(in_c, in_c, 2)
        self.conv_3 = conv3x3(in_c, in_c, 2)
        self.w_lut = w_lut
        if self.w_lut:  # With LUT: I1, I2, I3, I4
            self.conv_4 = conv3x3(in_c, in_c, 2)
            self.uni_conv = conv7x7(4 * in_c, in_dim, 2, padding=3)
        else:  # Without LUT: I1, I2, I3
            self.uni_conv = conv7x7(3 * in_c, in_dim, 2, padding=3)

        # self.res_1 = BasicBlock(inplanes=in_dim, planes=in_dim)
        # self.res_2 = BasicBlock(inplanes=in_dim, planes=in_dim)

    def forward(self, IMGS):
        if self.w_lut:
            adapter = torch.cat([self.conv_1(IMGS[0]), self.conv_2(IMGS[1]), self.conv_3(IMGS[2]), self.conv_4(IMGS[3])], dim=1)
        else:
            adapter = torch.cat([self.conv_1(IMGS[0]), self.conv_2(IMGS[1]), self.conv_3(IMGS[2])], dim=1)

        adapter = self.uni_conv(adapter)
        # adapter = self.res_1(adapter)
        # adapter = self.res_2(adapter)
        return adapter

# class Model_level_Adapeter(BaseModule):
#     def __init__(self, in_c=12, in_dim=16, w_lut=True):
#         super(Model_level_Adapeter, self).__init__()
#         self.conv_1 = conv3x3(in_c, in_c, 2)
#         self.conv_2 = conv3x3(in_c, in_c, 2)
#         self.conv_3 = conv3x3(in_c, in_c, 2)
#         self.w_lut = w_lut
#         if self.w_lut:
#             self.conv_4 = conv3x3(in_c, in_c, 2)
#             self.channel_reducer = conv7x7(272, in_dim, 2, padding=3)
#             self.uni_conv = conv7x7(12, in_dim, 2, padding=3)
#         else:
#             self.uni_conv = conv7x7(3*in_c, in_dim, 2, padding=3)

#         self.res_1 = BasicBlock(inplanes=in_dim, planes=in_dim)
#         self.res_2 = BasicBlock(inplanes=in_dim, planes=in_dim)
#         self.expand_channels = nn.Conv2d(3, 32, kernel_size=1, stride=1, padding=0).to(torch.float16)

#     def forward(self, IMGS):
#         device = IMGS[0].device
#         print("DEVICE", device)
#         IMGS = [img.to(torch.float16) for img in IMGS]

#         reduce_channels = nn.Conv2d(32, 12, kernel_size=1, bias=False).to(device).to(torch.float16)

#         IMGS = [img.to(self.expand_channels.weight.dtype) for img in IMGS]
#         print("Types: ", IMGS[0].dtype, self.expand_channels.weight.dtype)

#         IMGS = [reduce_channels(self.expand_channels(img)) for img in IMGS]

#         print(f"After first reduction: {IMGS[0].shape}")

#         temp_conv_1 = nn.Conv2d(12, 12, kernel_size=3, stride=2, padding=1, bias=False).to(device).to(torch.float16)
#         temp_conv_2 = nn.Conv2d(12, 12, kernel_size=3, stride=2, padding=1, bias=False).to(device).to(torch.float16)
#         temp_conv_3 = nn.Conv2d(12, 12, kernel_size=3, stride=2, padding=1, bias=False).to(device).to(torch.float16)
#         print("HERE HERE")
#         if self.w_lut:
#             temp_conv_4 = nn.Conv2d(12, 12, kernel_size=3, stride=2, padding=1, bias=False).to(device).to(torch.float16)
#             print("HERE HERE 22", len(IMGS))
#             adapter = torch.cat([
#                 temp_conv_1(IMGS[0]),
#                 temp_conv_2(IMGS[1]),
#                 temp_conv_3(IMGS[2]),
#                 temp_conv_4(IMGS[3])
#             ], dim=1)
#         else:
#             adapter = torch.cat([
#                 temp_conv_1(IMGS[0]),
#                 temp_conv_2(IMGS[1]),
#                 temp_conv_3(IMGS[2])
#             ], dim=1)

#         adapter = adapter.half()
#         print("HERE HERE 333")
#         adapter = self.uni_conv(adapter)
#         print("HERE HERE 44")
#         adapter = self.res_1(adapter)  # Residual Block 1
#         adapter = self.res_2(adapter)  # Residual Block 2
#         return adapter


import torch
import torch.nn as nn
import torch.nn.functional as F

# class Input_level_Adapeter(nn.Module):
#     def __init__(self, mode='normal', lut_dim=32, k_size=3, w_lut=True, in_channels=3):
#         """
#         Args:
#             mode (str): Operating mode. Can be 'normal' or another mode if you extend this module.
#             lut_dim (int): The output channel dimension if using the LUT branch.
#             k_size (int): Kernel size for the convolutional layers.
#             w_lut (bool): Whether to use the LUT branch.
#             in_channels (int): Number of input channels. Typically 3 for RGB/RAW images.
#         """
#         super(Input_level_Adapeter, self).__init__()
#         self.mode = mode
#         self.lut_dim = lut_dim
#         self.k_size = k_size
#         self.w_lut = w_lut

#         # First convolutional block.
#         self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=k_size, padding=k_size // 2, bias=False)
#         # self.bn1 = nn.LayerNorm(16)
#         self.relu = nn.ReLU(inplace=True)

#         # Second convolutional block.
#         self.conv2 = nn.Conv2d(16, 32, kernel_size=k_size, padding=k_size // 2, bias=False)
#         # self.bn2 = nn.LayerNorm(32)
#         self.bn1 = nn.GroupNorm(1, 16)
#         self.bn2 = nn.GroupNorm(1, 32)
#         # If using LUT processing, map the 32-channel features to lut_dim channels.
#         if self.w_lut:
#             self.lut_conv = nn.Conv2d(32, lut_dim, kernel_size=1, bias=False)

#         # Create two downsampling layers for multi-scale outputs.
#         self.down1 = nn.Conv2d(32 if not w_lut else lut_dim,
#                                32 if not w_lut else lut_dim,
#                                kernel_size=3, stride=2, padding=1, bias=False)
#         self.down2 = nn.Conv2d(32 if not w_lut else lut_dim,
#                                32 if not w_lut else lut_dim,
#                                kernel_size=3, stride=2, padding=1, bias=False)

#     def forward(self, x):
#         """
#         Forward pass for the input-level adapter.
#         Args:
#             x (Tensor): Input image tensor of shape (B, in_channels, H, W).
#         Returns:
#             List[Tensor]: A list of feature maps at multiple scales. For example:
#                 [feat_full, feat_down1, feat_down2]
#             where feat_down2 is the most downsampled feature used for adaptation.
#         """
#         # Initial conv block.
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)

#         out = self.conv2(out)
#         out = self.bn2(out)
#         out = self.relu(out)

#         # If enabled, adjust features via LUT branch.
#         if self.w_lut:
#             out = self.lut_conv(out)

#         # Compute multi-scale features.
#         feat_full = out  # Original resolution feature.
#         feat_down1 = self.relu(self.down1(feat_full))  # Downsampled by a factor of 2.
#         feat_down2 = self.relu(self.down2(feat_down1))  # Downsampled further.

#         # Return a list of features. In your transformer, you can pick the desired scale.
#         return [feat_full, feat_down1, feat_down2]

class CustomLayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape)

    def forward(self, x):
        # Input has shape [B, S, D] or [B, C, H, W]
        if len(x.shape) == 3:  # [B, S, D] or [B, C, S]
            if x.shape[1] == self.norm.normalized_shape[0]:
                # This is [B, C, S] format, need to transpose
                x = x.transpose(1, 2)  # Now [B, S, C]
                x = self.norm(x)
                x = x.transpose(1, 2)  # Back to [B, C, S]
            else:
                # Already in [B, S, C] format
                x = self.norm(x)
        elif len(x.shape) == 4:  # [B, C, H, W]
            b, c, h, w = x.shape
            # Reshape to [B, H*W, C]
            x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
            # Apply norm
            x = self.norm(x)
            # Reshape back to [B, C, H, W]
            x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        return x


# Predictor P_K
class Kernel_Predictor(nn.Module):
    def __init__(self, dim, mode='low', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Use provided scale factor or default to head_dim^-0.5
        self.scale = qk_scale or head_dim ** -0.5
        # Query Adaptive Learning (QAL)
        self.q = nn.Parameter(torch.rand((1, 4, dim)), requires_grad=True)
        # self.input_proj = nn.Conv2d(16, 3, kernel_size=1)
        self.kv_downsample = nn.Sequential(
            nn.Conv2d(3, dim // 8, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, dim // 8),  # replaced BatchNorm2d with GroupNorm
            nn.GELU(),
            nn.Conv2d(dim // 8, dim // 4, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, dim // 4),
            nn.GELU(),
            nn.Conv2d(dim // 4, dim // 2, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, dim // 2),
            nn.GELU(),
            nn.Conv2d(dim // 2, dim, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, dim),
        )

        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.down = nn.Linear(dim, 1)
        self.softmax = nn.Softmax(dim=2)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # Set basic parameters
        if mode == 'low':
            self.gain_base = nn.Parameter(torch.FloatTensor([3]), requires_grad=True)
        else:
            self.gain_base = nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.r1_base = nn.Parameter(torch.FloatTensor([3]), requires_grad=False)
        self.r2_base = nn.Parameter(torch.FloatTensor([2]), requires_grad=False)

    def forward(self, x):
        # print("X type: ", type(x), x.dim)
        # x = x[0]
        # x = self.input_proj(x)
        # output = self.kv_downsample(x)
        # print("Output type: ", type(output))  # This should print <class 'torch.Tensor'>, but it might print <class 'list'>

        d_x = self.kv_downsample(x).flatten(2).transpose(1, 2)  # [B, N, C]
        B, N, C = d_x.shape
        k = self.k(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, 4, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        out = self.down(out).squeeze(-1)
        out = torch.unbind(out, 1)
        r1, r2, gain, sigma = out[0], out[1], out[2], out[3]
        r1 = 0.1 * r1 + self.r1_base
        r2 = 0.1 * r2 + self.r2_base
        gain = gain + self.gain_base
        return r1, r2, gain, self.sigmoid(sigma)


class Matrix_Predictor(nn.Module):
    def __init__(self, dim, num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # Query Adaptive Learning (QAL)
        self.q = nn.Parameter(torch.rand((1, 9 + 1, dim)), requires_grad=True)
        self.kv_downsample = nn.Sequential(
            nn.Conv2d(3, dim // 8, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, dim // 8),
            nn.GELU(),
            nn.Conv2d(dim // 8, dim // 4, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, dim // 4),
            nn.GELU(),
            nn.Conv2d(dim // 4, dim // 2, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, dim // 2),
            nn.GELU(),
            nn.Conv2d(dim // 2, dim, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, dim),
        )
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.down = nn.Linear(dim, 1)
        self.softmax = nn.Softmax(dim=2)
        self.relu = nn.ReLU()
        self.ccm_base = nn.Parameter(torch.eye(3), requires_grad=False)

    def forward(self, x):
        d_x = self.kv_downsample(x).flatten(2).transpose(1, 2)
        B, N, C = d_x.shape
        k = self.k(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, 9 + 1, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        out = self.down(out)
        out, distance = out[:, :9, :], out[:, 9:, :].squeeze(-1)
        out = out.view(B, 3, 3)
        ccm_matrix = 0.1 * out + self.ccm_base
        distance = self.relu(distance) + 1
        return ccm_matrix, distance


# AAAI 2024 NILUT; the channel count is reduced here to keep FLOPs low
class NILUT(nn.Module):
    """
    Simple residual coordinate-based neural network for fitting 3D LUTs
    Official code: https://github.com/mv-lab/nilut
    """
    def __init__(self, in_features=3, hidden_features=32, hidden_layers=3, out_features=3, res=True):
        super().__init__()

        self.res = res
        self.net = []
        self.net.append(nn.Linear(in_features, hidden_features))
        self.net.append(nn.ReLU())

        for _ in range(hidden_layers):
            self.net.append(nn.Linear(hidden_features, hidden_features))
            self.net.append(nn.Tanh())

        self.net.append(nn.Linear(hidden_features, out_features))
        if not self.res:
            self.net.append(torch.nn.Sigmoid())

        self.net = nn.Sequential(*self.net)

    def forward(self, intensity):
        output = self.net(intensity)
        if self.res:
            output = output + intensity
            output = torch.clamp(output, 0., 1.)
        return output


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("Tensor is not a torch image.")


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)
    # print(x.device)
    # print(sigma.device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
    return kernel2d


def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
    need_squeeze = False
    # make image NCHW
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if out_dtype not in req_dtypes:
        need_cast = True
        req_dtype = req_dtypes[0]
        img = img.to(req_dtype)
    return img, need_cast, need_squeeze, out_dtype


from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad


def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            # it is better to round before cast
            img = torch.round(img)
        img = img.to(out_dtype)

    return img


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            kernel.dtype,
        ],
    )

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img

def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=3):  # [9, 9] in LOD dataset, [3, 3] in other dataset
    out = []
    for i in range(I1.shape[0]):
        I1_gain = gain[i] * I1[i, :, :, :]
        blur = gaussian_blur(I1_gain,
                             [k_size, k_size],
                             [r1[i], r2[i]])
        sharp = blur + sigma[i] * (I1[i, :, :, :] - blur)
        out.append(sharp)
    return torch.stack([out[i] for i in range(I1.shape[0])], dim=0)


def SoG_algo(img, p=1):
    # https://library.imaging.org/admin/apis/public/api/ist/website/downloadArticle/cic/12/1/art00008
    img = img.permute(1, 2, 0)  # (C,H,W) --> (H,W,C)

    img_P = torch.pow(img, p)

    R_avg = torch.mean(img_P[:, :, 0]) ** (1 / p)
    G_avg = torch.mean(img_P[:, :, 1]) ** (1 / p)
    B_avg = torch.mean(img_P[:, :, 2]) ** (1 / p)

    Avg = torch.mean(img_P) ** (1 / p)

    R_avg = R_avg / Avg
    G_avg = G_avg / Avg
    B_avg = B_avg / Avg

    img_out = torch.stack([img[:, :, 0] / R_avg, img[:, :, 1] / G_avg, img[:, :, 2] / B_avg], dim=-1)

    return img_out


def WB_CCM(I2, ccm_matrix, distance):
    out_I3 = []
    out_I4 = []
    for i in range(I2.shape[0]):
        # SOG White Balance Algorithm
        I3 = SoG_algo(I2[i, :, :, :], distance[i])

        # Camera Color Matrix
        I4 = torch.tensordot(I3, ccm_matrix[i, :, :], dims=[[-1], [-1]])
        I4 = torch.clamp(I4, 1e-5, 1.0)

        out_I3.append(I3)
        out_I4.append(I4)

    return torch.stack([out_I3[i] for i in range(I2.shape[0])], dim=0), \
        torch.stack([out_I4[i] for i in range(I2.shape[0])], dim=0)


class VitInputLevelAdapter(nn.Module):
    def __init__(self, mode='normal', lut_dim=32, out='all', k_size=3, w_lut=True):
        """
        Args:
            mode (str): Operating mode, e.g. 'normal' for normal/over-exposure or 'low' for low-light.
            lut_dim (int): Dimensionality for the implicit neural LUT.
            k_size (int): Kernel size for the denoising operation.
            w_lut (bool): Whether to use the implicit 3D Look-Up Table.
        """
        super(VitInputLevelAdapter, self).__init__()
        # These submodules predict transformation parameters from the input image.
        self.Predictor_K = Kernel_Predictor(dim=64, mode=mode)
        self.Predictor_M = Matrix_Predictor(dim=64)
        self.w_lut = w_lut
        if self.w_lut:
            self.LUT = NILUT(hidden_features=lut_dim)
            print("hidden_features", lut_dim)
            # self.LUT = nn.Linear(224, 32)
        self.k_size = k_size
        self.out = out

    def forward(self, I1):
        # (1). I1 --> I2: Denoise & Enhancement & Sharpen
        r1, r2, gain, sigma = self.Predictor_K(I1)
        I2 = Gain_Denoise(I1, r1, r2, gain, sigma, k_size=self.k_size)  # (B,C,H,W)
        I2 = torch.clamp(I2, 1e-5, 1.0)  # normal & over-exposure

        ccm_matrix, distance = self.Predictor_M(I2)
        # (2). I2 --> I3: White Balance, Shade of Gray
        # (3). I3 --> I4: Camera Colour Matrix Transformation
        I3, I4 = WB_CCM(I2, ccm_matrix, distance)  # (B,H,W,C)

        if self.w_lut:
            # (4). I4 --> I5: Implicit Neural LUT
            I5 = self.LUT(I4).permute(0, 3, 1, 2)

            if self.out == 'all':  # return all features
                return [I1, I2, I3.permute(0, 3, 1, 2), I4.permute(0, 3, 1, 2), I5]
            else:  # only return I5
                return [I5]

        else:
            if self.out == 'all':
                return [I1, I2, I3.permute(0, 3, 1, 2), I4.permute(0, 3, 1, 2)]
            else:
                return [I4.permute(0, 3, 1, 2)]
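As a quick shape check for the modules above, a minimal sketch (batch size, image size and the 257-token sequence length are illustrative; the dinov2.models.help module path is the one used by the vision transformer changes below):

import torch
from dinov2.models.help import Merge_block, Model_level_Adapeter, VitInputLevelAdapter

x = torch.rand(2, 3, 224, 224)                       # batch of images scaled to [0, 1]
pre = VitInputLevelAdapter(mode='normal', lut_dim=32, k_size=3, w_lut=True)
I1, I2, I3, I4, I5 = pre(x)                          # five (B, 3, H, W) stages of the learned ISP

adapter = Model_level_Adapeter(in_c=3, in_dim=16, w_lut=True)
ada = adapter([I1, I2, I3, I4])                      # (B, 16, H/4, W/4) adapter features

merge = Merge_block(fea_c=384, ada_c=16, mid_c=384, return_ada=True)
fea = torch.rand(2, 257, 768)                        # ViT tokens; embed dim is hard-coded to 768 in Merge_block
tok_ada = torch.rand(2, 257, 16)                     # adapter stream already projected to the token sequence
fea_out, next_ada = merge(fea, tok_ada, ratio=1.0)
print(fea_out.shape, next_ada.shape)                 # torch.Size([2, 257, 768]) torch.Size([2, 257, 32])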
@@ -18,7 +18,8 @@ import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block

from dinov2.models.help import Merge_block, Model_level_Adapeter
from dinov2.models.help import VitInputLevelAdapter as Input_level_Adapeter

logger = logging.getLogger("dinov2")


@@ -65,6 +66,17 @@ class DinoVisionTransformer(nn.Module):
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        # RAW adapter parameters
        w_lut=True,
        light_mode='normal',
        lut_dim=32,
        k_size=3,
        merge_ratio=1.0,
        model_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_model_adapter_weights.pth',
        input_level_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_pre_encoder_weights.pth',
        fea_c_s=[384, 768, 1920],
        ada_c_s=[16, 32, 64],
        mid_c_s=[384, 576, 768],
    ):
        """
        Args:

@@ -103,6 +115,13 @@ class DinoVisionTransformer(nn.Module):
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        # RAW adapter configuration
        self.w_lut = w_lut
        self.light_mode = light_mode
        self.lut_dim = lut_dim
        self.k_size = k_size
        self.merge_ratio = merge_ratio

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches


@@ -166,6 +185,29 @@ class DinoVisionTransformer(nn.Module):
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        # Initialize RAW adapter
        if self.w_lut:
            self.pre_encoder = Input_level_Adapeter(mode=light_mode, lut_dim=lut_dim, k_size=k_size, w_lut=w_lut)
            for param in self.pre_encoder.parameters():
                param.requires_grad_(True)

        # self.model_adapter = Model_level_Adapeter(in_c=in_chans, w_lut=w_lut)
        self.model_adapter = Model_level_Adapeter(in_c=3, in_dim=ada_c_s[0], w_lut=self.w_lut)
        if model_adapter_path is not None:
            print("Loading model adapter:", model_adapter_path)
            adapter_state = torch.load(model_adapter_path, map_location="cpu")
            self.model_adapter.load_state_dict(adapter_state, strict=False)
        if input_level_adapter_path is not None:
            print("Loading input-level adapter:", input_level_adapter_path)
            adapter_state = torch.load(input_level_adapter_path, map_location="cpu")
            self.pre_encoder.load_state_dict(adapter_state)

        self.merge_1 = Merge_block(fea_c=fea_c_s[0], ada_c=ada_c_s[0], mid_c=mid_c_s[0], return_ada=True)
        self.merge_2 = Merge_block(fea_c=fea_c_s[1], ada_c=ada_c_s[1], mid_c=mid_c_s[1], return_ada=True)
        self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)
        self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]

        self.init_weights()


@@ -175,6 +217,19 @@ class DinoVisionTransformer(nn.Module):
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

        if hasattr(self, "pre_encoder"):
            print("Init weights for pre-encoder")
            self.pre_encoder.apply(init_weights_vit_timm)

        if hasattr(self, "model_adapter"):
            print("Init weights for model adapter")
            self.model_adapter.apply(init_weights_vit_timm)

        if hasattr(self, "merge_blocks"):
            print("Init weights for merge blocks")
            for block in self.merge_blocks:
                block.apply(init_weights_vit_timm)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype

@@ -212,6 +267,15 @@ class DinoVisionTransformer(nn.Module):

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape

        x_raw = self.pre_encoder(x)
        if self.w_lut:  # I1, I2, I3, I4
            ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2], x_raw[3]])
        else:  # I1, I2, I3
            ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2]])

        x = x_raw[-1]

        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

@@ -229,12 +293,34 @@ class DinoVisionTransformer(nn.Module):
                dim=1,
            )

        return x
        ada = ada.reshape(ada.shape[0], ada.shape[1], -1)
        batch_size, channels, features = ada.shape
        target_seq_len = x.shape[1]
        linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
        ada = linear_proj(ada).permute(0, 2, 1)

        return x, ada

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
        # x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        x_s = []
        ada_list = []
        for x, masks in zip(x_list, masks_list):
            x_, ada = self.prepare_tokens_with_masks(x, masks)
            x_s.append(x_)
            ada_list.append(ada)

        x = x_s

        for i, blk in enumerate(self.blocks):
            x = blk(x)

            if self.w_lut and ada is not None and i < len(self.merge_blocks):
                x_ada_pairs = [self.merge_blocks[i](x_i, ada_i, ratio=self.merge_ratio) for x_i, ada_i in zip(x, ada_list)]
                x, ada_list = map(list, zip(*x_ada_pairs))

        all_x = x
        output = []

@@ -255,10 +341,12 @@ class DinoVisionTransformer(nn.Module):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)
        x, ada = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if self.w_lut and ada is not None and i < len(self.merge_blocks):
                x, ada = self.merge_blocks[i](x, ada, ratio=self.merge_ratio)

        x_norm = self.norm(x)
        return {
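The per-block interleaving introduced above can be exercised in isolation. A sketch with identity stand-ins for the transformer blocks (12 blocks and the 257-token length are illustrative; the Merge_block sizes match the fea_c_s/ada_c_s/mid_c_s defaults):

import torch
import torch.nn as nn
from dinov2.models.help import Merge_block

blocks = nn.ModuleList([nn.Identity() for _ in range(12)])   # stand-ins for ViT blocks
merge_blocks = [
    Merge_block(fea_c=384, ada_c=16, mid_c=384, return_ada=True),
    Merge_block(fea_c=768, ada_c=32, mid_c=576, return_ada=True),
    Merge_block(fea_c=1920, ada_c=64, mid_c=768, return_ada=False),
]

x = torch.rand(2, 257, 768)    # tokens as returned by prepare_tokens_with_masks
ada = torch.rand(2, 257, 16)   # adapter stream; channels double after each merge (16 -> 32 -> 64)
for i, blk in enumerate(blocks):
    x = blk(x)
    if ada is not None and i < len(merge_blocks):
        x, ada = merge_blocks[i](x, ada, ratio=1.0)
print(x.shape)                 # torch.Size([2, 257, 768])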
@@ -0,0 +1,139 @@
import numpy as np
import cv2
import argparse
from pathlib import Path
import torch
from PIL import Image
import numpy as np


def rgb_to_raw(image_path="", img=None):
    if img is None:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    else:
        img = np.array(img)

    if img is None:
        raise ValueError("Error loading image. Please check the path.")

    # print(img.shape, img.dtype)
    # cv2.imwrite("new.jpg", img)
    if len(img.shape) == 3:
        img_raw = img[:, :, 1]  # Extract green channel as a naive RAW simulation
    else:
        img_raw = img

    if img_raw.dtype != np.uint16:
        img_raw = (img_raw.astype(np.float32) / 255.0 * 65535).astype(np.uint16)

    # cv2.imwrite("new2.jpg", img)
    return img_raw

# def rgb_to_raw(image_path, local_crops_number=6):
#     """
#     Reads an image from disk, simulates a RAW image by extracting (for example)
#     the green channel, and returns a dictionary formatted like the output
#     of your DataAugmentationDINO pipeline.
#     """
#     # Read the image using OpenCV (unchanged mode)
#     img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
#     if img is None:
#         raise ValueError("Error loading image. Please check the path.")

#     # If the image has three channels, simulate RAW by taking the green channel.
#     if len(img.shape) == 3:
#         img_raw = img[:, :, 1]  # Using the green channel as a naive RAW simulation
#     else:
#         img_raw = img

#     # Convert to uint16 if needed.
#     if img_raw.dtype != np.uint16:
#         img_raw = (img_raw.astype(np.float32) / 255.0 * 65535).astype(np.uint16)

#     # Normalize the raw image to [0, 1] (as float32)
#     img_raw = img_raw.astype(np.float32) / 65535.0

#     # Convert the raw image to a torch tensor.
#     # Assuming the raw image is single channel, add a channel dimension.
#     raw_tensor = torch.from_numpy(img_raw).unsqueeze(0)  # Shape: [1, H, W]

#     # For consistency, we simulate two global crops (for student and teacher)
#     # and several local crops. Here we simply use the same raw tensor for each crop.
#     output = {
#         "global_crops": [raw_tensor, raw_tensor],  # Two global crops
#         "global_crops_teacher": [raw_tensor, raw_tensor],
#         "local_crops": [raw_tensor for _ in range(local_crops_number)],
#         "offsets": ()  # Keeping offsets empty as before
#     }
#     # print("Type: ", type(rgb_to_raw))
#     return output


import numpy as np
import cv2

def raw_to_rgb(raw_array, pattern='RGGB', image_size=(256, 256), bits=16):
    """
    Convert RAW sensor data to an RGB image, with handling of various bit depths
    and a simple white balance correction.

    Args:
        raw_array: NumPy array containing RAW sensor data
        pattern: Bayer pattern ('RGGB', 'BGGR', 'GRBG', 'GBRG')
        image_size: Tuple of (height, width)
        bits: Bit depth of the RAW data (typically 12, 14, or 16)

    Returns:
        RGB image as a PIL Image
    """
    if not isinstance(raw_array, np.ndarray):
        raise TypeError("Input must be a NumPy array.")

    # total_pixels = np.prod(image_size)
    # if raw_array.size != total_pixels:
    #     raise ValueError(f"Expected raw array size {total_pixels}, but got {raw_array.size}")

    # raw_image = raw_array.reshape(image_size)
    max_value = 2**bits - 1
    raw_image = raw_array
    if raw_image.dtype != np.uint8:
        raw_image = (raw_image / max_value * 255).astype(np.uint8)

    bayer_patterns = {
        'RGGB': cv2.COLOR_BayerBG2BGR,
        'BGGR': cv2.COLOR_BayerGB2BGR,
        'GRBG': cv2.COLOR_BayerGR2BGR,
        'GBRG': cv2.COLOR_BayerRG2BGR
    }

    rgb_image = cv2.demosaicing(raw_image, bayer_patterns[pattern])

    rgb_image_float = rgb_image.astype(float)
    for i in range(3):
        channel = rgb_image_float[:, :, i]
        mean_val = np.mean(channel)
        if mean_val > 0:
            channel *= 128 / mean_val
        rgb_image_float[:, :, i] = channel

    rgb_image = np.clip(rgb_image_float, 0, 255)
    # cv2.imwrite('output.jpg', rgb_image)
    pil_image = Image.fromarray(rgb_image.astype(np.uint8), mode='RGB')

    return pil_image


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert an RGB image to RAW format.")
    parser.add_argument("image_path", type=str, help="Path to the input image")

    args = parser.parse_args()

    raw = rgb_to_raw(args.image_path)
    print(raw, type(raw), raw.shape)
    # rgb = raw_to_rgb(raw)
    rgb = raw_to_rgb(raw, image_size=(512, 683))

    print(rgb)
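A quick round-trip sketch for the two helpers above on a synthetic array (no file I/O; the module path is the one imported by my_dataset.py earlier in this commit):

import numpy as np
from dinov2.train.rgb_to_raw import rgb_to_raw, raw_to_rgb

fake_rgb = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)   # stand-in RGB frame
raw = rgb_to_raw(img=fake_rgb)                                    # uint16 mosaic of shape (256, 256)
pil_img = raw_to_rgb(raw, pattern='RGGB', bits=16)
print(raw.dtype, raw.shape, pil_img.size, pil_img.mode)           # uint16 (256, 256) (256, 256) RGB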