fix stuff
parent
9d6d6d42b5
commit
279ec583f9
|
@ -38,22 +38,25 @@ class ImageDataset(Dataset):
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def _retrieve_images(self, path, is_valid=False, preserve=False, frac=1):
|
def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1):
|
||||||
images_ini = len(self.preserved_images)
|
images_ini = len(self.preserved_images)
|
||||||
images = []
|
images = []
|
||||||
for root, _, files in os.walk(path):
|
for root, _, files in os.walk(path):
|
||||||
images_dir = []
|
images_dir = []
|
||||||
for file in files:
|
for file in files:
|
||||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
|
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
|
||||||
|
im = os.path.join(root, file)
|
||||||
if is_valid:
|
if is_valid:
|
||||||
try:
|
try:
|
||||||
Image.open(os.path.join(root, file))
|
with open(im, 'rb') as f:
|
||||||
images_dir.append(os.path.join(root, file))
|
image_data = f.read()
|
||||||
|
ImageDataDecoder(image_data).decode()
|
||||||
|
images_dir.append(im)
|
||||||
|
|
||||||
except OSError:
|
except OSError:
|
||||||
print(f"Image at path {os.path.join(root, file)} could not be opened.")
|
print(f"Image at path {im} could not be opened.")
|
||||||
else:
|
else:
|
||||||
images_dir.append(os.path.join(root, file))
|
images_dir.append(im)
|
||||||
|
|
||||||
if preserve:
|
if preserve:
|
||||||
random.seed(24)
|
random.seed(24)
|
||||||
|
@ -62,6 +65,9 @@ class ImageDataset(Dataset):
|
||||||
self.preserved_images.extend(images_dir[:split_index])
|
self.preserved_images.extend(images_dir[:split_index])
|
||||||
images.extend(images_dir[split_index:])
|
images.extend(images_dir[split_index:])
|
||||||
|
|
||||||
|
else:
|
||||||
|
images.extend(images_dir)
|
||||||
|
|
||||||
images_end = len(self.preserved_images)
|
images_end = len(self.preserved_images)
|
||||||
if preserve:
|
if preserve:
|
||||||
print(f"{images_end - images_ini} images have been saved for the dataset at path {path}")
|
print(f"{images_end - images_ini} images have been saved for the dataset at path {path}")
|
||||||
|
|
|
@ -200,7 +200,7 @@ def do_train(cfg, model, resume=False):
|
||||||
# save the preserved images
|
# save the preserved images
|
||||||
|
|
||||||
if dataset.preserved_images:
|
if dataset.preserved_images:
|
||||||
write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl'))
|
write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl'), dataset.preserved_images)
|
||||||
|
|
||||||
# sampler_type = SamplerType.INFINITE
|
# sampler_type = SamplerType.INFINITE
|
||||||
sampler_type = SamplerType.SHARDED_INFINITE # define the sampler to use for fsdp
|
sampler_type = SamplerType.SHARDED_INFINITE # define the sampler to use for fsdp
|
||||||
|
|
Loading…
Reference in New Issue