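Fix image dataset validation and preserved-image saving: add an is_valid option to ImageDataset, accept ListConfig for path_preserved, validate images with ImageDataDecoder, and pass dataset.preserved_images to write_list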
parent
9d6d6d42b5
commit
d605692003
|
@ -2,17 +2,19 @@ import os
|
|||
import pathlib
|
||||
import random
|
||||
from typing import List
|
||||
from omegaconf.listconfig import ListConfig
|
||||
from torch.utils.data import Dataset
|
||||
from .decoders import ImageDataDecoder
|
||||
from PIL import Image
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1):
|
||||
def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1, is_valid=True):
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.path_preserved = path_preserved if isinstance(path_preserved, list) else [path_preserved]
|
||||
self.path_preserved = path_preserved if isinstance(path_preserved, (list, ListConfig)) else [path_preserved]
|
||||
self.frac = frac
|
||||
self.preserved_images = []
|
||||
self.is_valid = is_valid
|
||||
self.images_list = self._get_image_list()
|
||||
|
||||
def _get_image_list(self):
|
||||
|
@ -22,7 +24,7 @@ class ImageDataset(Dataset):
|
|||
try:
|
||||
p = self.root
|
||||
preserve = p in self.path_preserved
|
||||
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac))
|
||||
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid))
|
||||
|
||||
except OSError:
|
||||
print("The root given is nor a list nor a path")
|
||||
|
@ -31,29 +33,32 @@ class ImageDataset(Dataset):
|
|||
for p in self.root:
|
||||
try:
|
||||
preserve = p in self.path_preserved
|
||||
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac))
|
||||
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid))
|
||||
|
||||
except OSError:
|
||||
print(f"the path indicated at {p} cannot be found.")
|
||||
|
||||
return images
|
||||
|
||||
def _retrieve_images(self, path, is_valid=False, preserve=False, frac=1):
|
||||
def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1):
|
||||
images_ini = len(self.preserved_images)
|
||||
images = []
|
||||
for root, _, files in os.walk(path):
|
||||
images_dir = []
|
||||
for file in files:
|
||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
|
||||
im = os.path.join(root, file)
|
||||
if is_valid:
|
||||
try:
|
||||
Image.open(os.path.join(root, file))
|
||||
images_dir.append(os.path.join(root, file))
|
||||
with open(im, 'rb') as f:
|
||||
image_data = f.read()
|
||||
ImageDataDecoder(image_data).decode()
|
||||
images_dir.append(im)
|
||||
|
||||
except OSError:
|
||||
print(f"Image at path {os.path.join(root, file)} could not be opened.")
|
||||
print(f"Image at path {im} could not be opened.")
|
||||
else:
|
||||
images_dir.append(os.path.join(root, file))
|
||||
images_dir.append(im)
|
||||
|
||||
if preserve:
|
||||
random.seed(24)
|
||||
|
@ -62,6 +67,9 @@ class ImageDataset(Dataset):
|
|||
self.preserved_images.extend(images_dir[:split_index])
|
||||
images.extend(images_dir[split_index:])
|
||||
|
||||
else:
|
||||
images.extend(images_dir)
|
||||
|
||||
images_end = len(self.preserved_images)
|
||||
if preserve:
|
||||
print(f"{images_end - images_ini} images have been saved for the dataset at path {path}")
|
||||
|
|
|
@ -200,7 +200,7 @@ def do_train(cfg, model, resume=False):
|
|||
# save the preserved images
|
||||
|
||||
if dataset.preserved_images:
|
||||
write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl'))
|
||||
write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl'), dataset.preserved_images)
|
||||
|
||||
# sampler_type = SamplerType.INFINITE
|
||||
sampler_type = SamplerType.SHARDED_INFINITE # define the sampler to use for fsdp
|
||||
|
|
Loading…
Reference in New Issue