pull/511/head
Veronikkkka 2025-03-20 19:04:32 +00:00
parent 753b990c64
commit 20db539127
3 changed files with 465 additions and 33 deletions

View File

@@ -60,22 +60,22 @@ class DataAugmentationDINO(object):
         )
         # color distorsions / blurring
-        color_jittering = transforms.Compose(
-            [
-                transforms.RandomApply(
-                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
-                    p=0.8,
-                ),
-                transforms.RandomGrayscale(p=0.2),
-            ]
-        )
+        # color_jittering = transforms.Compose(
+        #     [
+        #         transforms.RandomApply(
+        #             [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
+        #             p=0.8,
+        #         ),
+        #         transforms.RandomGrayscale(p=0.2),
+        #     ]
+        # )
         global_transfo1_extra = GaussianBlur(p=1.0)
         global_transfo2_extra = transforms.Compose(
             [
                 GaussianBlur(p=0.1),
-                transforms.RandomSolarize(threshold=128, p=0.2),
+                # transforms.RandomSolarize(threshold=128, p=0.2),
             ]
         )
@@ -85,13 +85,13 @@ class DataAugmentationDINO(object):
         self.normalize = transforms.Compose(
             [
                 transforms.ToTensor(),
-                make_normalize_transform(),
+                # make_normalize_transform(),
             ]
         )
-        self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
-        self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
-        self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
+        self.global_transfo1 = transforms.Compose([global_transfo1_extra, self.normalize])
+        self.global_transfo2 = transforms.Compose([global_transfo2_extra, self.normalize])
+        self.local_transfo = transforms.Compose([local_transfo_extra, self.normalize])
     def __call__(self, image):
         output = {}
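
Net effect of these two hunks, as a minimal sketch (GaussianBlur, local_transfo_extra, and the transforms module come from the surrounding DINOv2 code, which is not shown in the diff): color jittering, solarization, and dataset normalization are all disabled, so the three pipelines reduce to blur plus ToTensor:

    # Sketch of the resulting pipelines (not part of the diff):
    global_transfo1 = transforms.Compose([GaussianBlur(p=1.0), transforms.ToTensor()])
    global_transfo2 = transforms.Compose([GaussianBlur(p=0.1), transforms.ToTensor()])
    local_transfo = transforms.Compose([local_transfo_extra, transforms.ToTensor()])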

View File

@@ -1,27 +1,30 @@
 import os
+import random
 from pathlib import Path
 from typing import Callable, Optional, Tuple, List
+import cv2
+import numpy as np
 import torch
 from torch.utils.data import Dataset
-import numpy as np
-import cv2
 from PIL import Image
+from dinov2.train.rgb_to_raw import rgb_to_raw, raw_to_rgb
 class ADK20Dataset(Dataset):
     def __init__(
         self,
         root: str,
+        annotations_file: str = "/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/sceneCategories.txt",
         transforms: Optional[Callable] = None,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         shuffle: bool = False,
     ) -> None:
         """
-        ADK20 Dataset for image classification.
+        ADK20 Dataset for image classification with labels.
         Args:
             root (str): Path to dataset directory.
+            annotations_file (str): Path to the annotations file containing image IDs and labels.
             transforms (Callable, optional): Combined image and target transformations.
             transform (Callable, optional): Image transformations.
             target_transform (Callable, optional): Target transformations.
@@ -31,19 +34,69 @@ class ADK20Dataset(Dataset):
         self.transforms = transforms
         self.transform = transform
         self.target_transform = target_transform
+        # Collect image file paths
         print("root:", self.root)
-        self.image_paths = sorted(self.root.rglob("*.jpg"))  # Adjust file format if needed
+        self.image_paths = sorted(list(self.root.rglob("*.jpg")) + list(self.root.rglob("*.JPEG")))
         if not self.image_paths:
             raise ValueError(f"No images found in dataset directory: {root}")
+        # Load annotations
+        self.labels = {}
+        self.class_to_idx = {}
+        self.idx_to_class = {}
+        self._load_annotations(annotations_file)
+        # Filter image paths to only include those with annotations
+        self.image_paths = [p for p in self.image_paths if self._get_image_id(p) in self.labels]
         if shuffle:
             import random
             random.shuffle(self.image_paths)
         self.true_len = len(self.image_paths)
-        print(f"Loaded {self.true_len} images from {root}")
+        print(f"Loaded {self.true_len} images with labels from {root}")
+    def _load_annotations(self, annotations_file: str) -> None:
+        """
+        Load annotations from the specified file.
+        Args:
+            annotations_file (str): Path to the annotations file.
+        """
+        try:
+            with open(annotations_file, 'r') as f:
+                lines = f.readlines()
+            for line in lines:
+                parts = line.strip().split()
+                if len(parts) >= 2:
+                    image_id = parts[0]
+                    class_name = parts[1]
+                    # Add class to mapping if not already present
+                    if class_name not in self.class_to_idx:
+                        idx = len(self.class_to_idx)
+                        self.class_to_idx[class_name] = idx
+                        self.idx_to_class[idx] = class_name
+                    # Store label for this image
+                    self.labels[image_id] = self.class_to_idx[class_name]
+            print(f"Loaded {len(self.labels)} annotations with {len(self.class_to_idx)} unique classes")
+        except Exception as e:
+            print(f"Error loading annotations: {e}")
+            raise
+    def _get_image_id(self, filepath: Path) -> str:
+        """
+        Extract image ID from filepath.
+        Args:
+            filepath (Path): Path to the image file.
+        Returns:
+            str: Image ID (filename without extension).
+        """
+        return filepath.stem
     def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
         """
@@ -57,7 +110,7 @@ class ADK20Dataset(Dataset):
         """
         adjusted_index = index % self.true_len  # Avoid division by zero error
         filepath = str(self.image_paths[adjusted_index])
-        # print("filepath:", filepath)
         try:
             image = Image.open(filepath).convert("RGB")
         except Exception as e:
@@ -67,19 +120,18 @@ class ADK20Dataset(Dataset):
         if self.transform:
             image = self.transform(image)
-        target = torch.zeros((1,))  # Modify if ADK20 has labels
+        # Get label for this image
+        image_id = self._get_image_id(Path(filepath))
+        if image_id in self.labels:
+            target = torch.tensor(self.labels[image_id])
+        else:
+            # Use -1 as label for images without annotations
+            target = torch.tensor(-1)
         if self.target_transform:
             target = self.target_transform(target)
-        # raw_image = rgb_to_raw(filepath)
-        # after_raw = raw_to_rgb(raw_image)
-        # print("Img:", image)
-        # print(type(image), type(raw_image))
-        # print(type(image), type(after_raw), image.keys())
         return image, target, filepath
-        # return raw_image, target, filepath
-        # return image, raw_image, target, filepath
     def __len__(self) -> int:
         return self.true_len
@@ -113,5 +165,29 @@ class ADK20Dataset(Dataset):
         }
         # print("Type: ", type(rgb_to_raw))
         return output
+    def get_targets(self) -> np.ndarray:
+        """
+        Returns target labels for all dataset samples.
+        Returns:
+            np.ndarray: Array of class indices for each sample.
+        """
+        targets = []
+        for path in self.image_paths:
+            image_id = self._get_image_id(path)
+            if image_id in self.labels:
+                targets.append(self.labels[image_id])
+            else:
+                targets.append(-1)  # Use -1 for unknown labels
+        return np.array(targets, dtype=np.int64)
+    def get_classes(self) -> List[str]:
+        """
+        Returns the list of class names.
+        Returns:
+            List[str]: List of class names.
+        """
+        return [self.idx_to_class[i] for i in range(len(self.idx_to_class))]
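
Each line of sceneCategories.txt is expected to pair an image ID with a scene class (for ADE20K, e.g. "ADE_train_00000001 airport_terminal"); a minimal usage sketch under that assumption, with illustrative paths:

    # Usage sketch (illustrative paths, not from the commit):
    dataset = ADK20Dataset(
        root="/data/ADEChallengeData2016/images/training",
        annotations_file="/data/ADEChallengeData2016/sceneCategories.txt",
    )
    image, target, filepath = dataset[0]  # PIL image, class-index tensor (or -1 for unannotated files)
    print(len(dataset.get_classes()), dataset.get_targets()[:5])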

View File

@@ -0,0 +1,356 @@
# -*- coding: utf-8 -*-
"""Input_level_adapter.ipynb
Automatically generated by Colab.
Original file is located at
    https://colab.research.google.com/drive/167em6SRMZhY0L7Lo-zY2ibkV0qQ7kMWJ
"""
import os
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad


class Kernel_Predictor(nn.Module):
    def __init__(self, dim, mode='low', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        # Query Adaptive Learning (QAL)
        self.q = nn.Parameter(torch.rand((1, 4, dim)), requires_grad=True)

        self.kv_downsample = nn.Sequential(
            nn.Conv2d(3, dim // 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(dim // 8),
            nn.GELU(),
            nn.Conv2d(dim // 8, dim // 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(dim // 4),
            nn.GELU(),
            nn.Conv2d(dim // 4, dim // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(dim // 2),
            nn.GELU(),
            nn.Conv2d(dim // 2, dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(dim),
        )
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.down = nn.Linear(dim, 1)
        self.softmax = nn.Softmax(dim=2)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # Base parameter values
        if mode == 'low':
            self.gain_base = nn.Parameter(torch.FloatTensor([3]), requires_grad=True)
        else:
            self.gain_base = nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.r1_base = nn.Parameter(torch.FloatTensor([3]), requires_grad=False)
        self.r2_base = nn.Parameter(torch.FloatTensor([2]), requires_grad=False)

    def forward(self, x):
        d_x = self.kv_downsample(x).flatten(2).transpose(1, 2)
        B, N, C = d_x.shape
        k = self.k(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, 4, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        out = self.down(out).squeeze(-1)
        out = torch.unbind(out, 1)

        r1, r2, gain, sigma = out[0], out[1], out[2], out[3]
        r1 = 0.1 * r1 + self.r1_base
        r2 = 0.1 * r2 + self.r2_base
        gain = gain + self.gain_base
        return r1, r2, gain, self.sigmoid(sigma)
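
# Usage sketch (illustrative, not part of the original file): four learned queries
# attend over a 16x-downsampled feature map, and each query decodes to one scalar,
# so a batch of B images yields four (B,)-shaped parameter tensors:
#   kp = Kernel_Predictor(dim=64)
#   r1, r2, gain, sigma = kp(torch.rand(2, 3, 256, 256))  # each of shape (2,)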
class Matrix_Predictor(nn.Module):
    def __init__(self, dim, num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        # Query Adaptive Learning (QAL)
        self.q = nn.Parameter(torch.rand((1, 9 + 1, dim)), requires_grad=True)

        self.kv_downsample = nn.Sequential(
            nn.Conv2d(3, dim // 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(dim // 8),
            nn.GELU(),
            nn.Conv2d(dim // 8, dim // 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(dim // 4),
            nn.GELU(),
            nn.Conv2d(dim // 4, dim // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(dim // 2),
            nn.GELU(),
            nn.Conv2d(dim // 2, dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(dim),
        )
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.down = nn.Linear(dim, 1)
        self.softmax = nn.Softmax(dim=2)
        self.relu = nn.ReLU()
        self.ccm_base = nn.Parameter(torch.eye(3), requires_grad=False)

    def forward(self, x):
        d_x = self.kv_downsample(x).flatten(2).transpose(1, 2)
        B, N, C = d_x.shape
        k = self.k(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(d_x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, 9 + 1, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        out = self.down(out)
        out, distance = out[:, :9, :], out[:, 9:, :].squeeze(-1)
        out = out.view(B, 3, 3)
        ccm_matrix = 0.1 * out + self.ccm_base
        distance = self.relu(distance) + 1
        return ccm_matrix, distance
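
# Usage sketch (illustrative): nine of the ten queries decode a near-identity
# 3x3 colour-correction matrix; the tenth decodes the Shades-of-Gray norm:
#   mp = Matrix_Predictor(dim=64)
#   ccm, distance = mp(torch.rand(2, 3, 256, 256))  # ccm: (2, 3, 3); distance: (2, 1), >= 1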
class NILUT(nn.Module):
    """
    Simple residual coordinate-based neural network for fitting 3D LUTs.
    Official code: https://github.com/mv-lab/nilut
    """
    def __init__(self, in_features=3, hidden_features=32, hidden_layers=3, out_features=3, res=True):
        super().__init__()
        self.res = res
        self.net = []
        self.net.append(nn.Linear(in_features, hidden_features))
        self.net.append(nn.ReLU())
        for _ in range(hidden_layers):
            self.net.append(nn.Linear(hidden_features, hidden_features))
            self.net.append(nn.Tanh())
        self.net.append(nn.Linear(hidden_features, out_features))
        if not self.res:
            self.net.append(torch.nn.Sigmoid())
        self.net = nn.Sequential(*self.net)

    def forward(self, intensity):
        output = self.net(intensity)
        if self.res:
            output = output + intensity
            output = torch.clamp(output, 0., 1.)
        return output
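
# Usage sketch (illustrative): the LUT acts pointwise on RGB triplets, so it can
# be applied directly to a channels-last batch in [0, 1]:
#   lut = NILUT(hidden_features=32)
#   mapped = lut(torch.rand(2, 64, 64, 3))  # same shape; residual + clamp keeps [0, 1]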
def _assert_image_tensor(img: Tensor) -> None:
    if not img.ndim >= 2:
        raise TypeError("Tensor is not a torch image.")


def _get_gaussian_kernel1d(kernel_size: int, sigma: Tensor) -> Tensor:
    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()
    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[Tensor], dtype: torch.dtype, device: torch.device
) -> Tensor:
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
    return kernel2d


def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
    need_squeeze = False
    # make image NCHW
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True
    out_dtype = img.dtype
    need_cast = False
    if out_dtype not in req_dtypes:
        need_cast = True
        req_dtype = req_dtypes[0]
        img = img.to(req_dtype)
    return img, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
    if need_squeeze:
        img = img.squeeze(dim=0)
    if need_cast:
        if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            # it is better to round before cast
            img = torch.round(img)
        img = img.to(out_dtype)
    return img


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[Tensor]) -> Tensor:
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")
    _assert_image_tensor(img)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            kernel.dtype,
        ],
    )
    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])
    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
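
# Usage sketch (illustrative): unlike torchvision's stock gaussian_blur, this
# variant accepts per-call *tensor* sigmas, keeping the blur differentiable with
# respect to the predicted r1/r2:
#   img = torch.rand(3, 64, 64)
#   sig = torch.tensor(1.5)
#   blurred = gaussian_blur(img, [3, 3], [sig, sig])  # same shape as img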
def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=3):  # [9, 9] in the LOD dataset, [3, 3] in other datasets
    out = []
    for i in range(I1.shape[0]):
        I1_gain = gain[i] * I1[i, :, :, :]
        blur = gaussian_blur(I1_gain,
                             [k_size, k_size],
                             [r1[i], r2[i]])
        sharp = blur + sigma[i] * (I1[i, :, :, :] - blur)
        out.append(sharp)
    return torch.stack([out[i] for i in range(I1.shape[0])], dim=0)


# Shades of Gray and Colour Constancy (Graham D. Finlayson, Elisabetta Trezzi)
def SoG_algo(img, p=1):
    # https://library.imaging.org/admin/apis/public/api/ist/website/downloadArticle/cic/12/1/art00008
    img = img.permute(1, 2, 0)  # (C,H,W) --> (H,W,C)
    img_P = torch.pow(img, p)
    R_avg = torch.mean(img_P[:, :, 0]) ** (1 / p)
    G_avg = torch.mean(img_P[:, :, 1]) ** (1 / p)
    B_avg = torch.mean(img_P[:, :, 2]) ** (1 / p)
    Avg = torch.mean(img_P) ** (1 / p)
    R_avg = R_avg / Avg
    G_avg = G_avg / Avg
    B_avg = B_avg / Avg
    img_out = torch.stack([img[:, :, 0] / R_avg, img[:, :, 1] / G_avg, img[:, :, 2] / B_avg], dim=-1)
    return img_out
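
# Note (illustrative): the Minkowski norm p spans the classic constancy
# assumptions -- p=1 is Gray-World and p -> inf approaches Max-RGB; Finlayson &
# Trezzi report p=6 as a good default:
#   balanced = SoG_algo(torch.rand(3, 64, 64), p=6)  # returns (H, W, C)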
def WB_CCM(I2, ccm_matrix, distance):
    out_I3 = []
    out_I4 = []
    for i in range(I2.shape[0]):
        # SoG White Balance Algorithm
        I3 = SoG_algo(I2[i, :, :, :], distance[i])
        # Camera Colour Matrix
        I4 = torch.tensordot(I3, ccm_matrix[i, :, :], dims=[[-1], [-1]])
        I4 = torch.clamp(I4, 1e-5, 1.0)
        out_I3.append(I3)
        out_I4.append(I4)
    return torch.stack([out_I3[i] for i in range(I2.shape[0])], dim=0), \
           torch.stack([out_I4[i] for i in range(I2.shape[0])], dim=0)


class Input_level_Adapeter(nn.Module):
    def __init__(self, mode='normal', lut_dim=32, out='all', k_size=3, w_lut=True):
        super(Input_level_Adapeter, self).__init__()
        '''
        mode: normal (for normal & over-exposure conditions) or low (for low-light conditions)
        lut_dim: implicit neural look-up table dim number
        out: if 'all', return all intermediate results I1..I5; otherwise return only the final image
        k_size: denoising kernel size, must be an odd number; set to 9 for the LOD dataset and 3 for other datasets
        w_lut: with or without implicit 3D look-up table
        '''
        self.Predictor_K = Kernel_Predictor(dim=64, mode=mode)
        self.Predictor_M = Matrix_Predictor(dim=64)
        self.w_lut = w_lut
        if self.w_lut:
            self.LUT = NILUT(hidden_features=lut_dim)
        self.out = out
        self.k_size = k_size

    def forward(self, I1):
        # (1). I1 --> I2: Denoise & Enhancement & Sharpen
        r1, r2, gain, sigma = self.Predictor_K(I1)
        I2 = Gain_Denoise(I1, r1, r2, gain, sigma, k_size=self.k_size)  # (B,C,H,W)
        I2 = torch.clamp(I2, 1e-5, 1.0)  # normal & over-exposure
        ccm_matrix, distance = self.Predictor_M(I2)
        # (2). I2 --> I3: White Balance, Shades of Gray
        # (3). I3 --> I4: Camera Colour Matrix Transformation
        I3, I4 = WB_CCM(I2, ccm_matrix, distance)  # (B,H,W,C)
        if self.w_lut:
            # (4). I4 --> I5: Implicit Neural LUT
            I5 = self.LUT(I4).permute(0, 3, 1, 2)
            if self.out == 'all':  # return all intermediate results
                return [I1, I2, I3.permute(0, 3, 1, 2), I4.permute(0, 3, 1, 2), I5]
            else:  # only return I5
                return [I5]
        else:
            if self.out == 'all':
                return [I1, I2, I3.permute(0, 3, 1, 2), I4.permute(0, 3, 1, 2)]
            else:
                return [I4.permute(0, 3, 1, 2)]


if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES'] = '4'
    input = torch.rand([4, 3, 512, 512])
    net = Input_level_Adapeter(out='all', w_lut=False)
    out = net(input)
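    # Shape check (illustrative, assuming the inputs above): with out='all' and
    # w_lut=False the adapter returns [I1, I2, I3, I4], each (4, 3, 512, 512);
    # with w_lut=True a fifth, LUT-refined image I5 is appended.
    print([o.shape for o in out])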