mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix offline_evaluate index error (#630)
* [Fix] Fix offline eval dataset index error. * update * update
This commit is contained in:
parent
f2b0540f58
commit
a9a575866f
@ -104,12 +104,6 @@ class Evaluator:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# support chunking iterable objects
|
# support chunking iterable objects
|
||||||
if data is not None:
|
|
||||||
assert len(data_samples) == len(data), (
|
|
||||||
'outputs and data should have the same length, but got '
|
|
||||||
f'outputs length: {len(data_samples)} '
|
|
||||||
f'data length: {len(data)}')
|
|
||||||
|
|
||||||
def get_chunks(seq: Iterator, chunk_size=1):
|
def get_chunks(seq: Iterator, chunk_size=1):
|
||||||
stop = False
|
stop = False
|
||||||
while not stop:
|
while not stop:
|
||||||
@ -123,10 +117,17 @@ class Evaluator:
|
|||||||
if chunk:
|
if chunk:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
assert len(data_samples) == len(data), (
|
||||||
|
'data_samples and data should have the same length, but got '
|
||||||
|
f'data_samples length: {len(data_samples)} '
|
||||||
|
f'data length: {len(data)}')
|
||||||
|
data = get_chunks(iter(data), chunk_size)
|
||||||
|
|
||||||
size = 0
|
size = 0
|
||||||
for output_chunk in get_chunks(iter(data_samples), chunk_size):
|
for output_chunk in get_chunks(iter(data_samples), chunk_size):
|
||||||
if data is not None:
|
if data is not None:
|
||||||
data_chunk = pseudo_collate(data[size:size + chunk_size])
|
data_chunk = pseudo_collate(next(data)) # type: ignore
|
||||||
else:
|
else:
|
||||||
data_chunk = None
|
data_chunk = None
|
||||||
size += len(output_chunk)
|
size += len(output_chunk)
|
||||||
|
@ -247,7 +247,7 @@ class TestEvaluator(TestCase):
|
|||||||
all_data = [dict() for _ in range(9)]
|
all_data = [dict() for _ in range(9)]
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
AssertionError,
|
AssertionError,
|
||||||
'outputs and data should have the same length'):
|
'data_samples and data should have the same length'):
|
||||||
evaluator.offline_evaluate(all_predictions, all_data)
|
evaluator.offline_evaluate(all_predictions, all_data)
|
||||||
|
|
||||||
@unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
|
@unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user