Fiddling with iterator wrapping for HF ds streaming

This commit is contained in:
Ross Wightman 2024-01-09 12:41:54 -08:00
parent 992976f007
commit 2eac2f6955

View File

@ -167,9 +167,11 @@ class ReaderHfids(Reader):
target_sample_count = self._num_samples_per_worker()
sample_count = 0
ds_iter = iter(self.ds)
if self.is_training:
ds_iter = chain.from_iterable(repeat(ds_iter))
ds_iter = chain.from_iterable(repeat(self.ds))
else:
ds_iter = iter(self.ds)
for sample in ds_iter:
input_data: Image.Image = sample[self.input_key]
if self.input_img_mode and input_data.mode != self.input_img_mode: