EasyCV/easycv/hooks/swav_hook.py

52 lines
1.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import torch
from mmcv.runner import Hook, get_dist_info
from .registry import HOOKS
@HOOKS.register_module
class SWAVHook(Hook):
    """Hook that maintains the per-rank feature queue for SWAV training.

    The hook attaches ``queue`` and ``queue_path`` attributes to
    ``runner.model.module``, restores the queue from disk when a dump file
    for the current rank already exists, allocates a zero-filled queue once
    the configured start epoch is reached, and saves the queue after every
    training epoch.

    Args:
        gpu_batch_size (int): Batch size; it is multiplied by the world size
            to obtain the granularity the queue length is rounded down to
            (presumably the per-GPU batch size -- confirm against the
            training config).
        dump_path (str): Directory in which each rank stores its
            ``queue<rank>.pth`` file. It is created when missing.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, gpu_batch_size=32, dump_path='data/', **kwargs):
        self.dump_path = dump_path
        # Resolved in ``before_run`` from the model config.
        self.queue_length = None
        self.rank, self.world_size = get_dist_info()
        self.batch_size = gpu_batch_size
        # In distributed runs several processes may build this hook at the
        # same time. A separate "exists?" check followed by ``makedirs`` is
        # racy and can raise FileExistsError on the slower ranks, so let
        # ``makedirs`` tolerate an already existing directory instead.
        os.makedirs(self.dump_path, exist_ok=True)

    def before_run(self, runner):
        """Attach queue state to the model and resolve the queue length.

        Sets ``queue`` (``None`` unless a dump file for this rank exists) and
        ``queue_path`` on ``runner.model.module``, then reads
        ``config['queue_length']`` from the model and rounds it down to a
        multiple of ``batch_size * world_size``.

        Args:
            runner: Runner whose ``model.module`` exposes a ``config``
                mapping containing ``'queue_length'``.
        """
        model = runner.model.module
        model.queue = None
        model.queue_path = os.path.join(
            self.dump_path, 'queue' + str(self.rank) + '.pth')

        # Resume from the queue dumped by a previous run, if any.
        if os.path.isfile(model.queue_path):
            model.queue = torch.load(model.queue_path)['queue']

        # The queue needs to be divisible by the total batch size.
        self.queue_length = model.config['queue_length']
        self.queue_length -= self.queue_length % (
            self.batch_size * self.world_size)

    def before_train_epoch(self, runner):
        """Allocate the queue once the configured start epoch is reached.

        A zero-filled CUDA tensor of shape
        ``(len(crops_for_assign), queue_length // world_size, feat_dim)`` is
        created only when the queue is enabled (``queue_length > 0``), the
        current epoch is at least ``config['epoch_queue_starts']`` and no
        queue exists yet (neither restored from disk nor allocated earlier).

        Args:
            runner: Runner providing ``epoch`` and ``model.module``.
        """
        model = runner.model.module
        if (self.queue_length > 0
                and runner.epoch >= model.config['epoch_queue_starts']
                and model.queue is None):
            model.queue = torch.zeros(
                len(model.config['crops_for_assign']),
                self.queue_length // self.world_size,
                model.feat_dim,
            ).cuda()

    def after_train_epoch(self, runner):
        """Save the current queue to ``queue_path`` if one exists.

        Args:
            runner: Runner whose ``model.module`` holds ``queue`` and
                ``queue_path``.
        """
        model = runner.model.module
        if model.queue is not None:
            torch.save({'queue': model.queue}, model.queue_path)