fix stuff

pull/422/head
Etienne Guevel 2024-05-26 01:00:55 +02:00
parent 9d6d6d42b5
commit 279ec583f9
2 changed files with 12 additions and 6 deletions

View File

@@ -38,22 +38,25 @@ class ImageDataset(Dataset):
return images return images
def _retrieve_images(self, path, is_valid=False, preserve=False, frac=1): def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1):
images_ini = len(self.preserved_images) images_ini = len(self.preserved_images)
images = [] images = []
for root, _, files in os.walk(path): for root, _, files in os.walk(path):
images_dir = [] images_dir = []
for file in files: for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')): if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
im = os.path.join(root, file)
if is_valid: if is_valid:
try: try:
Image.open(os.path.join(root, file)) with open(im, 'rb') as f:
images_dir.append(os.path.join(root, file)) image_data = f.read()
ImageDataDecoder(image_data).decode()
images_dir.append(im)
except OSError: except OSError:
print(f"Image at path {os.path.join(root, file)} could not be opened.") print(f"Image at path {im} could not be opened.")
else: else:
images_dir.append(os.path.join(root, file)) images_dir.append(im)
if preserve: if preserve:
random.seed(24) random.seed(24)
@@ -62,6 +65,9 @@ class ImageDataset(Dataset):
self.preserved_images.extend(images_dir[:split_index]) self.preserved_images.extend(images_dir[:split_index])
images.extend(images_dir[split_index:]) images.extend(images_dir[split_index:])
else:
images.extend(images_dir)
images_end = len(self.preserved_images) images_end = len(self.preserved_images)
if preserve: if preserve:
print(f"{images_end - images_ini} images have been saved for the dataset at path {path}") print(f"{images_end - images_ini} images have been saved for the dataset at path {path}")

View File

@@ -200,7 +200,7 @@ def do_train(cfg, model, resume=False):
# save the preserved images # save the preserved images
if dataset.preserved_images: if dataset.preserved_images:
write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl')) write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl'), dataset.preserved_images)
# sampler_type = SamplerType.INFINITE # sampler_type = SamplerType.INFINITE
sampler_type = SamplerType.SHARDED_INFINITE # define the sampler to use for fsdp sampler_type = SamplerType.SHARDED_INFINITE # define the sampler to use for fsdp