new dataset classes

pull/511/head
Veronikkkka 2025-03-26 12:51:56 +00:00
parent 3784ec00ce
commit 3d03635c63
5 changed files with 302 additions and 7 deletions

View File

@@ -86,8 +86,9 @@ data_transform: "default"
 train:
   batch_size_per_gpu: 16 # vitg: 26+, vitl: 56, vits: 152, vitb: 120 for 8 nodes
   num_workers: 1
-  OFFICIAL_EPOCH_LENGTH: 1000 # 1250
-  dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/ #MIT:root=/home/paperspace/Documents/nika_space/mit_dataset/train/
+  OFFICIAL_EPOCH_LENGTH: 500 # 1250
+  dataset_path: RAISE:root=/home/paperspace/Documents/nika_space/RAISE_383.csv
+  #RAW_NOD:root=/home/paperspace/Documents/nika_space/raw_nod_dataset/RAW-NOD/image/ #ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/
   centering: sinkhorn_knopp
   drop_path_rate: 0.4
@@ -100,7 +101,7 @@ teacher:
   momentum_teacher: 0.994
 optim:
   epochs: 50 # 500
-  weight_decay_end: 0.3
+  weight_decay_end: 0.35
   base_lr: 0.0001 # learning rate for a batch size of 1024
   warmup_epochs: 20 # 80
   layerwise_decay: 1.0
@@ -114,9 +115,9 @@ evaluation:
 # "dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"
 student:
-  arch: dinov2_vitb14
+  arch: dinov2_vitb14 #vit_base #dinov2_vitb14
   patch_size: 14
-  merge_block_indexes: "" # comma-separated block indices
+  merge_block_indexes: "0, 6, 10" # comma-separated block indices
 crops:
   global_crops_scale:
   - 0.32 # 0.32 default
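
The diff switches merge_block_indexes from the empty string to "0, 6, 10". The parsing code is not part of this commit; below is a minimal sketch of how such a comma-separated config string could be turned into block indices (the function name is hypothetical):

def parse_merge_block_indexes(value: str) -> list[int]:
    """Turn a config string like "0, 6, 10" into [0, 6, 10]."""
    value = value.strip()
    if not value:
        return []  # "" disables block merging entirely
    return [int(token) for token in value.split(",")]

assert parse_merge_block_indexes("0, 6, 10") == [0, 6, 10]
assert parse_merge_block_indexes("") == []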

View File

@@ -5,4 +5,6 @@
 from .image_net import ImageNet
 from .image_net_22k import ImageNet22k
-from .my_dataset import ADK20Dataset
+from .my_dataset import ADK20Dataset
+from .raw_nod import RAWNODDataset
+from .raise_dataset import RaiseDataset

View File

@@ -0,0 +1,104 @@
import csv
import os
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import rawpy
import requests
from PIL import Image
from torch.utils.data import Dataset


class RaiseDataset(Dataset):
    def __init__(
        self,
        root: str,
        download_dir: str = "./raw_nod_images",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        """
        RAISE dataset for image classification.

        Args:
            root (str): Path to the dataset CSV file.
            download_dir (str): Directory where downloaded images are stored.
            transform (Callable, optional): Image transformations.
            target_transform (Callable, optional): Transformations for labels.
        """
        self.download_dir = Path(download_dir)
        self.transform = transform
        self.target_transform = target_transform
        self.download_dir.mkdir(parents=True, exist_ok=True)

        self.image_info = []
        self._load_csv(root)

    def _load_csv(self, csv_file: str) -> None:
        """Loads the dataset index from the CSV file and downloads missing images."""
        with open(csv_file, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            i = 0
            for row in reader:
                if i > 100:
                    break  # Cap at the first 101 rows (debugging limit).
                nef_url = row["NEF"]
                file_name = os.path.basename(nef_url)
                file_path = self.download_dir / file_name

                # Labels come from the "Keywords" and "Scene Mode" columns.
                labels = [row["Keywords"], row["Scene Mode"]]

                self.image_info.append({
                    "nef_url": nef_url,
                    "file_path": file_path,
                    "labels": labels,
                })

                # Download the image if it is not already cached locally.
                if not file_path.exists():
                    self._download_image(nef_url, file_path)
                i += 1
        print(f"Dataset loaded: {len(self.image_info)} images.")

    def _download_image(self, url: str, file_path: Path) -> None:
        """Downloads an image from a URL to file_path."""
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(file_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"Downloaded: {file_path.name}")
        except requests.RequestException as e:
            print(f"Failed to download {url}: {e}")

    def __len__(self) -> int:
        return len(self.image_info)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, List[str]]:
        """Returns an image and its corresponding labels."""
        info = self.image_info[idx]
        # NEF files are camera raws that PIL cannot decode directly,
        # so demosaic with rawpy first (as RAWNODDataset does).
        raw = rawpy.imread(str(info["file_path"]))
        image = Image.fromarray(raw.postprocess())
        labels = info["labels"]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            labels = self.target_transform(labels)

        return image, labels

    def get_targets(self) -> List[List[str]]:
        """Returns the labels for all samples in the dataset."""
        return [info["labels"] for info in self.image_info]
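
A hypothetical usage sketch for RaiseDataset (the CSV path mirrors the dataset_path in the config above; the transform choice is illustrative, not part of this commit):

from torchvision import transforms

dataset = RaiseDataset(
    root="/home/paperspace/Documents/nika_space/RAISE_383.csv",
    download_dir="./raise_images",
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
)
image, labels = dataset[0]  # transformed image, [keywords, scene mode]
print(len(dataset), labels)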

View File

@@ -0,0 +1,184 @@
import json
import random
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import numpy as np
import rawpy
import torch
from PIL import Image
from torch.utils.data import Dataset


class RAWNODDataset(Dataset):
    def __init__(
        self,
        root: str,
        annotations_file: str = "/home/paperspace/Documents/nika_space/raw_nod_dataset/RAW-NOD/annotations/Nikon/raw_new_Nikon750_train.json",
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        shuffle: bool = False,
    ) -> None:
        """
        RAW-NOD dataset for image classification with labels.

        Args:
            root (str): Path to the dataset directory.
            annotations_file (str): Path to the COCO-style annotations file with image IDs and labels.
            transforms (Callable, optional): Combined image and target transformations.
            transform (Callable, optional): Image transformations.
            target_transform (Callable, optional): Target transformations.
            shuffle (bool, optional): If True, shuffles the dataset. Defaults to False.
        """
        self.root = Path(root)
        self.transforms = transforms
        self.transform = transform
        self.target_transform = target_transform
        print("root:", self.root)

        # Annotation state, filled in by _load_annotations().
        self.image_info = {}    # image_id -> image metadata
        self.labels = {}        # image_id -> list of class ids
        self.class_to_idx = {}  # class name -> class id
        self.idx_to_class = {}  # class id -> class name
        self._load_annotations(annotations_file)

        # Collect raw NEF files and JPEG previews.
        self.image_paths = sorted(list(self.root.rglob("*.NEF")) + list(self.root.rglob("*.JPEG")))
        if not self.image_paths:
            raise ValueError(f"No images found in dataset directory: {root}")

        # Keep only images that have annotations.
        self.image_paths = [p for p in self.image_paths if self._get_image_id(p) in self.labels]
        if shuffle:
            random.shuffle(self.image_paths)

        self.true_len = len(self.image_paths)
        print(f"Loaded {self.true_len} images with labels from {root}")

    def _load_annotations(self, annotation_file: str) -> None:
        """
        Load COCO-style annotations from a JSON file.

        Args:
            annotation_file (str): Path to the COCO annotation file.
        """
        try:
            with open(annotation_file, "r") as f:
                data = json.load(f)

            self.image_info = {}
            self.labels = {}
            self.class_to_idx = {}
            self.idx_to_class = {}

            # Process category labels.
            for category in data["categories"]:
                self.class_to_idx[category["name"]] = category["id"]
                self.idx_to_class[category["id"]] = category["name"]

            # Process image metadata.
            for img in data["images"]:
                self.image_info[img["id"]] = img
                self.labels[img["id"]] = []

            # Process annotations: each one attaches a category id to an image.
            for anno in data["annotations"]:
                img_id = anno["image_id"]
                category_id = anno["category_id"]
                if img_id in self.labels:
                    self.labels[img_id].append(category_id)

            print(f"Loaded {len(self.labels)} annotations for {len(self.image_info)} images.")
        except Exception as e:
            print(f"Error loading annotations: {e}")
            raise

    def _get_image_id(self, filepath: Path) -> str:
        """Returns the image ID (filename without extension) for a file path."""
        return filepath.stem

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """
        Loads and returns an image, target, and filepath.

        Args:
            index (int): Dataset index.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, str]: (image, target, filepath)
        """
        adjusted_index = index % self.true_len  # Wrap indices from samplers that run past the dataset length.
        filepath = str(self.image_paths[adjusted_index])
        try:
            # NEF files are camera raws that PIL cannot decode directly,
            # so demosaic with rawpy first.
            raw = rawpy.imread(filepath)
            rgb = raw.postprocess()
            image = Image.fromarray(rgb)
        except Exception as e:
            print(f"Error loading image {filepath}: {e}")
            return self.__getitem__((index + 1) % self.true_len)  # Skip to the next valid image.

        if self.transform:
            image = self.transform(image)

        # Look up the class ids for this image; -1 marks missing annotations.
        image_id = self._get_image_id(Path(filepath))
        if image_id in self.labels:
            target = torch.tensor(self.labels[image_id])
        else:
            target = torch.tensor(-1)

        if self.target_transform:
            target = self.target_transform(target)

        return image, target, filepath

    def __len__(self) -> int:
        return self.true_len

    def get_targets(self) -> np.ndarray:
        """
        Returns target labels for all dataset samples.

        Returns:
            np.ndarray: Array of class indices, one per sample.
        """
        targets = []
        for path in self.image_paths:
            image_id = self._get_image_id(path)
            if image_id in self.labels and self.labels[image_id]:
                # Images can carry several category ids; keep the first so
                # the result is a flat int64 array.
                targets.append(self.labels[image_id][0])
            else:
                targets.append(-1)  # -1 for images without annotations
        return np.array(targets, dtype=np.int64)

    def get_classes(self) -> List[str]:
        """
        Returns the list of class names.

        Returns:
            List[str]: List of class names.
        """
        # Category ids may not start at 0, so iterate over the known ids.
        return [self.idx_to_class[i] for i in sorted(self.idx_to_class)]
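
A hypothetical usage sketch for RAWNODDataset (root matches the commented-out RAW_NOD path in the config above; the annotations file falls back to the constructor default):

dataset = RAWNODDataset(
    root="/home/paperspace/Documents/nika_space/raw_nod_dataset/RAW-NOD/image/",
    shuffle=True,
)
image, target, filepath = dataset[0]  # demosaiced PIL image, tensor of class ids, source path
print(dataset.get_classes()[:5], target.tolist(), filepath)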

View File

@@ -13,7 +13,7 @@ from torch.utils.data import Sampler

 from .datasets import ImageNet, ImageNet22k
 from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
 from .datasets import (
-    ADK20Dataset
+    ADK20Dataset, RAWNODDataset, RaiseDataset
 )

 logger = logging.getLogger("dinov2")

@@ -60,6 +60,10 @@ def _parse_dataset_str(dataset_str: str):
     # class_ = HemaStandardDataset
     if name == "ImageNet":
         class_ = ADK20Dataset
+    elif name == "RAW_NOD":
+        class_ = RAWNODDataset
+    elif name == "RAISE":
+        class_ = RaiseDataset
     else:
         raise ValueError(f'Unsupported dataset "{name}"')
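
With these branches in place, the dataset_path strings from the config resolve to the new classes. A hedged sketch, assuming the surrounding loaders.py keeps DINOv2's make_dataset(dataset_str=...) entry point, which calls _parse_dataset_str and forwards the key=value pairs (here root) as constructor kwargs:

from dinov2.data.loaders import make_dataset

raise_ds = make_dataset(
    dataset_str="RAISE:root=/home/paperspace/Documents/nika_space/RAISE_383.csv",
)
raw_nod_ds = make_dataset(
    dataset_str="RAW_NOD:root=/home/paperspace/Documents/nika_space/raw_nod_dataset/RAW-NOD/image/",
)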