convert RGB
parent
9e43e49463
commit
6aaa74ce49
|
@ -2,6 +2,7 @@ import os
|
|||
import pathlib
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from .decoders import ImageDataDecoder
|
||||
from PIL import Image
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
|
@ -37,7 +38,7 @@ class ImageDataset(Dataset):
|
|||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
|
||||
if is_valid:
|
||||
try:
|
||||
Image.open(os.path.join(root, file)).convert('RGB')
|
||||
Image.open(os.path.join(root, file))
|
||||
images.append(os.path.join(root, file))
|
||||
|
||||
except OSError:
|
||||
|
@ -46,15 +47,25 @@ class ImageDataset(Dataset):
|
|||
images.append(os.path.join(root, file))
|
||||
|
||||
return images
|
||||
|
||||
def get_image_data(self, index: int):
|
||||
path = self.images_list[index]
|
||||
with open(path) as f:
|
||||
image_data = f.read()
|
||||
|
||||
return image_data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images_list)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image_path = self.images_list[idx]
|
||||
image = Image.open(image_path)
|
||||
def __getitem__(self, index: int):
|
||||
try:
|
||||
image_data = self.get_image_data(index)
|
||||
image = ImageDataDecoder(image_data).decode()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"can nor read image for sample {index}") from e
|
||||
|
||||
if self.transform:
|
||||
if self.transform is not None:
|
||||
image = self.transform(image)
|
||||
|
||||
return image
|
||||
|
|
Loading…
Reference in New Issue