[Fix] Fix offline_evaluate index error (#630)

* [Fix] Fix offline eval dataset index error.

* update

* update
This commit is contained in:
RangiLyu 2022-10-21 16:53:57 +08:00 committed by Zaida Zhou
parent f2b0540f58
commit a9a575866f
2 changed files with 9 additions and 8 deletions

View File

@ -104,12 +104,6 @@ class Evaluator:
""" """
# support chunking iterable objects # support chunking iterable objects
if data is not None:
assert len(data_samples) == len(data), (
'outputs and data should have the same length, but got '
f'outputs length: {len(data_samples)} '
f'data length: {len(data)}')
def get_chunks(seq: Iterator, chunk_size=1): def get_chunks(seq: Iterator, chunk_size=1):
stop = False stop = False
while not stop: while not stop:
@ -123,10 +117,17 @@ class Evaluator:
if chunk: if chunk:
yield chunk yield chunk
if data is not None:
assert len(data_samples) == len(data), (
'data_samples and data should have the same length, but got '
f'data_samples length: {len(data_samples)} '
f'data length: {len(data)}')
data = get_chunks(iter(data), chunk_size)
size = 0 size = 0
for output_chunk in get_chunks(iter(data_samples), chunk_size): for output_chunk in get_chunks(iter(data_samples), chunk_size):
if data is not None: if data is not None:
data_chunk = pseudo_collate(data[size:size + chunk_size]) data_chunk = pseudo_collate(next(data)) # type: ignore
else: else:
data_chunk = None data_chunk = None
size += len(output_chunk) size += len(output_chunk)

View File

@ -247,7 +247,7 @@ class TestEvaluator(TestCase):
all_data = [dict() for _ in range(9)] all_data = [dict() for _ in range(9)]
with self.assertRaisesRegex( with self.assertRaisesRegex(
AssertionError, AssertionError,
'outputs and data should have the same length'): 'data_samples and data should have the same length'):
evaluator.offline_evaluate(all_predictions, all_data) evaluator.offline_evaluate(all_predictions, all_data)
@unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu') @unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')