[Fix] Fix offline_evaluate index error (#630)

* [Fix] Fix offline eval dataset index error.

* update

* update
This commit is contained in:
RangiLyu 2022-10-21 16:53:57 +08:00 committed by Zaida Zhou
parent f2b0540f58
commit a9a575866f
2 changed files with 9 additions and 8 deletions

View File

@ -104,12 +104,6 @@ class Evaluator:
"""
# support chunking iterable objects
if data is not None:
assert len(data_samples) == len(data), (
'outputs and data should have the same length, but got '
f'outputs length: {len(data_samples)} '
f'data length: {len(data)}')
def get_chunks(seq: Iterator, chunk_size=1):
stop = False
while not stop:
@ -123,10 +117,17 @@ class Evaluator:
if chunk:
yield chunk
if data is not None:
assert len(data_samples) == len(data), (
'data_samples and data should have the same length, but got '
f'data_samples length: {len(data_samples)} '
f'data length: {len(data)}')
data = get_chunks(iter(data), chunk_size)
size = 0
for output_chunk in get_chunks(iter(data_samples), chunk_size):
if data is not None:
data_chunk = pseudo_collate(data[size:size + chunk_size])
data_chunk = pseudo_collate(next(data)) # type: ignore
else:
data_chunk = None
size += len(output_chunk)

View File

@ -247,7 +247,7 @@ class TestEvaluator(TestCase):
all_data = [dict() for _ in range(9)]
with self.assertRaisesRegex(
AssertionError,
'outputs and data should have the same length'):
'data_samples and data should have the same length'):
evaluator.offline_evaluate(all_predictions, all_data)
@unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')