mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Fix]: Fix distributed setting for shape bias (#689)
* [Fix]: Fix dist bug * [Fix]: Fix lint * [Feature]: Add download log info * [Fix]: Fix lint
This commit is contained in:
parent
0f634de85b
commit
d3f57edf79
@ -6,6 +6,7 @@ from typing import List, Sequence
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from mmengine.dist.utils import get_rank
|
||||||
from mmengine.evaluator import BaseMetric
|
from mmengine.evaluator import BaseMetric
|
||||||
|
|
||||||
from mmselfsup.registry import METRICS
|
from mmselfsup.registry import METRICS
|
||||||
@ -75,7 +76,8 @@ class ShapeBiasMetric(BaseMetric):
|
|||||||
self.csv_dir = csv_dir
|
self.csv_dir = csv_dir
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.dataset_name = dataset_name
|
self.dataset_name = dataset_name
|
||||||
self.csv_path = self.create_csv()
|
if get_rank() == 0:
|
||||||
|
self.csv_path = self.create_csv()
|
||||||
|
|
||||||
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
|
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
|
||||||
"""Process one batch of data samples.
|
"""Process one batch of data samples.
|
||||||
@ -126,7 +128,7 @@ class ShapeBiasMetric(BaseMetric):
|
|||||||
os.remove(csv_path)
|
os.remove(csv_path)
|
||||||
directory = osp.dirname(csv_path)
|
directory = osp.dirname(csv_path)
|
||||||
if not osp.exists(directory):
|
if not osp.exists(directory):
|
||||||
os.makedirs(directory)
|
os.makedirs(directory, exist_ok=True)
|
||||||
with open(csv_path, 'w') as f:
|
with open(csv_path, 'w') as f:
|
||||||
writer = csv.writer(f)
|
writer = csv.writer(f)
|
||||||
writer.writerow([
|
writer.writerow([
|
||||||
@ -161,7 +163,8 @@ class ShapeBiasMetric(BaseMetric):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: A dict of metrics.
|
dict: A dict of metrics.
|
||||||
"""
|
"""
|
||||||
self.dump_results_to_csv(results)
|
if get_rank() == 0:
|
||||||
|
self.dump_results_to_csv(results)
|
||||||
metrics = dict()
|
metrics = dict()
|
||||||
metrics['accuracy/top1'] = np.mean([
|
metrics['accuracy/top1'] = np.mean([
|
||||||
result['pred_category'][0] == result['gt_category']
|
result['pred_category'][0] == result['gt_category']
|
||||||
|
@ -8,6 +8,7 @@ import matplotlib as mpl
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
from utils import FormatStrFormatter, ShapeBias
|
from utils import FormatStrFormatter, ShapeBias
|
||||||
|
|
||||||
# global default boundary settings for thin gray transparent
|
# global default boundary settings for thin gray transparent
|
||||||
@ -46,12 +47,24 @@ parser.add_argument(
|
|||||||
help= # noqa
|
help= # noqa
|
||||||
'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
|
'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--delete-icons',
|
||||||
|
action='store_true',
|
||||||
|
help='whether to delete the icons after plotting')
|
||||||
|
|
||||||
humans = [
|
humans = [
|
||||||
'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
|
'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
|
||||||
'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
|
'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
icon_names = [
|
||||||
|
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
|
||||||
|
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
|
||||||
|
'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
|
||||||
|
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
|
||||||
|
'knife.png', 'response_icons_vertical.png', 'truck.png'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def read_csvs(csv_dir: str) -> pd.DataFrame:
|
def read_csvs(csv_dir: str) -> pd.DataFrame:
|
||||||
"""Reads all csv files in a directory and returns a single dataframe.
|
"""Reads all csv files in a directory and returns a single dataframe.
|
||||||
@ -83,7 +96,6 @@ def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
|
|||||||
mpl.rcParams['font.serif'] = ['Times New Roman']
|
mpl.rcParams['font.serif'] = ['Times New Roman']
|
||||||
|
|
||||||
plt.figure(figsize=(9, 7))
|
plt.figure(figsize=(9, 7))
|
||||||
|
|
||||||
df = read_csvs(args.csv_dir)
|
df = read_csvs(args.csv_dir)
|
||||||
|
|
||||||
fontsize = 15
|
fontsize = 15
|
||||||
@ -218,19 +230,27 @@ def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
|
|||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
def check_icons() -> bool:
|
||||||
icon_names = [
|
"""Check if icons are present, if not download them."""
|
||||||
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
|
if not osp.exists(ICONS_DIR):
|
||||||
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png',
|
return False
|
||||||
'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
|
|
||||||
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
|
|
||||||
'knife.png', 'response_icons_vertical.png', 'truck.png'
|
|
||||||
]
|
|
||||||
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
|
|
||||||
os.makedirs(ICONS_DIR, exist_ok=True)
|
|
||||||
for icon_name in icon_names:
|
for icon_name in icon_names:
|
||||||
url = osp.join(root_url, icon_name)
|
if not osp.exists(osp.join(ICONS_DIR, icon_name)):
|
||||||
os.system('wget -O {} {}'.format(osp.join(ICONS_DIR, icon_name), url))
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
if not check_icons():
|
||||||
|
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
|
||||||
|
os.makedirs(ICONS_DIR, exist_ok=True)
|
||||||
|
MMLogger.get_current_instance().info(
|
||||||
|
f'Downloading icons to {ICONS_DIR}')
|
||||||
|
for icon_name in icon_names:
|
||||||
|
url = osp.join(root_url, icon_name)
|
||||||
|
os.system('wget -O {} {}'.format(
|
||||||
|
osp.join(ICONS_DIR, icon_name), url))
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
|
assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
|
||||||
@ -260,5 +280,5 @@ if __name__ == '__main__':
|
|||||||
args.plotting_names.append('Humans')
|
args.plotting_names.append('Humans')
|
||||||
|
|
||||||
plot_shape_bias_matrixplot(args)
|
plot_shape_bias_matrixplot(args)
|
||||||
|
if args.delete_icons:
|
||||||
os.system('rm -rf {}'.format(ICONS_DIR))
|
os.system('rm -rf {}'.format(ICONS_DIR))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user