# Adapted from https://github.com/microsoft/GLIP. The original license is:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
r"""
Basic training script for PyTorch
"""
# Set up the custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip
import argparse
import os
import random
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from maskrcnn_benchmark.config import cfg, try_to_find
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.engine.trainer import do_train
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.solver import make_lr_scheduler, make_optimizer
from maskrcnn_benchmark.utils.amp import autocast, GradScaler
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, all_gather
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.metric_logger import MetricLogger, TensorboardLogger
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
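# (CUDA_LAUNCH_BLOCKING=1 above makes every kernel launch synchronous, so CUDA
# errors surface at the offending op instead of a later, unrelated call; it
# slows training and is typically only left on for debugging.)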


def tuning_highlevel_override(cfg):
    # Both tuning modes freeze the visual backbone/FPN and train the language
    # backbone together with the vision-query modules; their settings are
    # identical here, so they share one branch.
    if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE in ("vision_query", "vs_with_txt_enc"):
        cfg.MODEL.BACKBONE.FREEZE = True
        cfg.MODEL.FPN.FREEZE = True
        # Keep the RPN trainable only when query fusion adds parameters to it
        cfg.MODEL.RPN.FREEZE = not cfg.VISION_QUERY.QUERY_FUSION
        cfg.MODEL.LINEAR_PROB = False
        cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False
        cfg.MODEL.LANGUAGE_BACKBONE.FREEZE = False
        cfg.MODEL.DYHEAD.USE_CHECKPOINT = False  # Disable gradient checkpointing
        cfg.VISION_QUERY.ENABLED = True
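
# A minimal sketch of driving this override from the command line; the script
# and YAML paths below are illustrative, only the yacs option name comes from
# this file:
#   python train_net.py --config-file configs/example.yaml \
#       SOLVER.TUNING_HIGHLEVEL_OVERRIDE vision_query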


def train(cfg, local_rank, distributed, use_tensorboard=False, resume=False):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.GROUNDINGDINO.enabled:
        if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vision_query":
            # Train only the query pre-selection and query-to-vision layers
            for key, p in model.named_parameters():
                if not ('pre_select' in key or 'qv_layer' in key):
                    p.requires_grad = False
    else:
        if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vision_query":
            if model.language_backbone is not None:
                for key, p in model.language_backbone.named_parameters():
                    if not ('pre_select' in key or 'qv_layer' in key):
                        p.requires_grad = False
            if cfg.VISION_QUERY.QUERY_FUSION:
                if model.rpn is not None:
                    for key, p in model.rpn.named_parameters():
                        if not ('pre_select' in key or 'qv_layer' in key):
                            p.requires_grad = False

    if cfg.MODEL.BACKBONE.RESET_BN:
        # Reset BatchNorm running statistics to their defaults
        for name, param in model.named_buffers():
            if 'running_mean' in name:
                torch.nn.init.constant_(param, 0)
            if 'running_var' in name:
                torch.nn.init.constant_(param, 1)

    if cfg.SOLVER.GRAD_CLIP > 0:
        clip_value = cfg.SOLVER.GRAD_CLIP
        for p in filter(lambda p: p.grad is not None, model.parameters()):
            p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))
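        # NOTE: no backward pass has run yet, so p.grad is None for every
        # parameter and the filter above typically registers no hooks; filtering
        # on p.requires_grad would actually attach the element-wise clamps.
        # Kept as-is to preserve the original (GLIP-inherited) behavior.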
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=0,  # <TODO> Sampling from the resume point is disabled due to a conflict with max_epoch
    )

    if cfg.TEST.DURING_TRAINING or cfg.SOLVER.USE_AUTOSTEP:
        data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
        data_loaders_val = data_loaders_val[0]
    else:
        data_loaders_val = None
    if cfg.GROUNDINGDINO.enabled:
        pass
    else:
        if cfg.MODEL.BACKBONE.FREEZE:
            for p in model.backbone.body.parameters():
                p.requires_grad = False

        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            print("LANGUAGE_BACKBONE FROZEN.")
            for p in model.language_backbone.body.parameters():
                p.requires_grad = False

        if cfg.MODEL.FPN.FREEZE:
            for p in model.backbone.fpn.parameters():
                p.requires_grad = False

        if cfg.MODEL.RPN.FREEZE:
            for p in model.rpn.parameters():
                p.requires_grad = False

    # if cfg.SOLVER.PROMPT_PROBING_LEVEL != -1:
    #     if cfg.SOLVER.PROMPT_PROBING_LEVEL == 1:
    #         for p in model.parameters():
    #             p.requires_grad = False
    #         for p in model.language_backbone.body.parameters():
    #             p.requires_grad = True
    #         for name, p in model.named_parameters():
    #             if p.requires_grad:
    #                 print(name, " : Not Frozen")
    #             else:
    #                 print(name, " : Frozen")
    #     else:
    #         assert 0
    optimizer = make_optimizer(cfg, model)
    print('Making scheduler')
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        print('Distributing model')
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # Broadcast buffers only when BN is used, to keep running stats in sync
            broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN,
            find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS
        )
        print('Done')

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    print('Making checkpointer')
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    if resume and cfg.OUTPUT_DIR != "OUTPUT":
        if not os.path.exists(cfg.OUTPUT_DIR):
            load_weight = cfg.MODEL.WEIGHT
        else:
            # Scan OUTPUT_DIR for the latest intermediate checkpoint (model_XXXXXXX.pth)
            checkpoint_list = [name for name in os.listdir(cfg.OUTPUT_DIR)
                               if name.endswith('.pth') and 'final' not in name and 'resume' not in name]
            if len(checkpoint_list) == 0:
                load_weight = cfg.MODEL.WEIGHT
                resume = False
            else:
                max_bits = len(checkpoint_list[0].split('.')[0].split('_')[-1])
                iter_list = [int(name.split('.')[0].split('_')[-1]) for name in checkpoint_list]
                max_iter = str(max(iter_list)).zfill(max_bits)
                resume_weight_name = 'model_' + max_iter + '.pth'
                load_weight = str(Path(cfg.OUTPUT_DIR, resume_weight_name))
    else:
        load_weight = cfg.MODEL.WEIGHT

    print('Loading checkpoint')
    extra_checkpoint_data = checkpointer.load(try_to_find(load_weight))
    arguments.update(extra_checkpoint_data)

    # Enable resume: skip the iterations that have already been trained
    data_loader.batch_sampler.start_iter = arguments["iteration"] + 1 if resume else 0
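    # Example: with model_0005000.pth and model_0010000.pth in OUTPUT_DIR, the
    # scan above resumes from model_0010000.pth and the sampler starts at
    # iteration 10001.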
    if cfg.DATASETS.FEW_SHOT:
        arguments["dataset_ids"] = data_loader.dataset.ids

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if use_tensorboard:
        meters = TensorboardLogger(
            log_dir=cfg.OUTPUT_DIR,
            start_iter=arguments["iteration"],
            delimiter="  "
        )
    else:
        meters = MetricLogger(delimiter="  ")

    if is_main_process():
        for name, p in model.named_parameters():
            if p.requires_grad:
                print(name, " : Not Frozen")
            else:
                print(name, " : Frozen")

    do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        data_loaders_val,
        meters,
    )

    return model


def setup_for_distributed(is_master):
    """
    Disable printing on non-master processes (pass force=True to override).
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
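
# After setup_for_distributed, only the master rank prints by default; any rank
# can still force output through the patched built-in:
#   print("something went wrong on this rank", force=True)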


def extract_query(cfg):
    if cfg.DATASETS.FEW_SHOT:
        assert cfg.DATASETS.FEW_SHOT == cfg.VISION_QUERY.MAX_QUERY_NUMBER, \
            'To extract the right query instances, set VISION_QUERY.MAX_QUERY_NUMBER = DATASETS.FEW_SHOT.'
    # if cfg.num_gpus > 1:
    #     max_query_number = cfg.VISION_QUERY.MAX_QUERY_NUMBER
    #     cfg.defrost()
    #     cfg.VISION_QUERY.MAX_QUERY_NUMBER = int(cfg.VISION_QUERY.MAX_QUERY_NUMBER / cfg.num_gpus)
    #     cfg.freeze()
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    checkpointer = DetectronCheckpointer(cfg, model)
    checkpointer.load(try_to_find(cfg.MODEL.WEIGHT))

    data_loader = make_data_loader(
        cfg,
        is_train=False,
        is_cache=True,
        is_distributed=cfg.num_gpus > 1,
    )
    assert isinstance(data_loader, list) and len(data_loader) == 1
    data_loader = data_loader[0]
    # if cfg.VISION_QUERY.CUSTOM_DATA_IDS is not None:
    #     data_loader.dataset.ids = cfg.VISION_QUERY.CUSTOM_DATA_IDS

    if cfg.num_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[cfg.local_rank], output_device=cfg.local_rank,
            broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN,
            find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS
        )

    query_images = defaultdict(list)
    _iterator = tqdm(data_loader)
    # _iterator = data_loader  # for debugging
    model.eval()
    for i, batch in enumerate(_iterator):
        images, targets, *_ = batch
        if cfg.num_gpus > 1:
            query_images = model.module.extract_query(images.to(device), targets, query_images)
        else:
            query_images = model.extract_query(images.to(device), targets, query_images)
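
    # query_images maps each category label to the query features gathered over
    # the dataset (extract_query is expected to cap each class at
    # VISION_QUERY.MAX_QUERY_NUMBER entries).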
    if cfg.num_gpus > 1:
        # Gathering the full bank with all_gather is not stable (easy to OOM),
        # so each rank saves its own shard instead:
        # all_query_images = all_gather(query_images)
        # if is_main_process():
        #     accumulated_query_images = defaultdict(list)
        #     for r, query_images_dict in enumerate(all_query_images):
        #         print('accumulating results: {}/{}'.format(r, len(all_query_images)))
        #         for label, feat in query_images_dict.items():
        #             num_queries = len(accumulated_query_images[label])
        #             if num_queries >= cfg.VISION_QUERY.MAX_QUERY_NUMBER:
        #                 continue
        #             if num_queries == 0:
        #                 accumulated_query_images[label] = feat.to(device)
        #             else:
        #                 accumulated_query_images[label] = torch.cat([accumulated_query_images[label].to(device), feat.to(device)])
        #     save_name = 'MODEL/{}_query_{}_pool{}_{}{}_multi-node.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0], cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION, 'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME)
        #     print('saving to ', save_name)
        #     torch.save(accumulated_query_images, save_name)
        if cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH != '':
            raise NotImplementedError
        global_rank = get_rank()
        save_name = 'MODEL/{}_query_{}_pool{}_{}{}_rank{}.pth'.format(
            cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0],
            cfg.VISION_QUERY.MAX_QUERY_NUMBER,
            cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION,
            'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all',
            cfg.VISION_QUERY.QUERY_ADDITION_NAME,
            global_rank)
        print('saving to ', save_name)
        torch.save(query_images, save_name)
    else:
        if cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH != '':
            save_name = cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH
        else:
            save_name = 'MODEL/{}_query_{}_pool{}_{}{}.pth'.format(
                cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0],
                cfg.VISION_QUERY.MAX_QUERY_NUMBER,
                cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION,
                'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all',
                cfg.VISION_QUERY.QUERY_ADDITION_NAME)
        print('saving to ', save_name)
        torch.save(query_images, save_name)
    # Unfinished sketch: merge the per-rank shards on the main process.
    # if cfg.num_gpus > 1:
    #     world_size = torch.distributed.get_world_size()
    #     if is_main_process():
    #         query_images_list = []
    #         for r in range(world_size):
    #             saved_path = 'MODEL/{}_query_{}_pool{}_{}{}_rank{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0], cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION, 'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME, r)
    #             query_images_list.append(torch.load(saved_path, map_location='cpu'))
    #         for s in query_images_list
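
# A typical single-GPU query-bank extraction launch (script and config paths
# are illustrative):
#   python train_net.py --config-file configs/example.yaml --extract_query \
#       VISION_QUERY.MAX_QUERY_NUMBER 5 DATASETS.FEW_SHOT 5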


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local-rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "--use-tensorboard",
        dest="use_tensorboard",
        help="Use tensorboardX logger (requires tensorboardX installed)",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument("--save_original_config", action="store_true")
    parser.add_argument("--disable_output_distributed", action="store_true")
    parser.add_argument("--override_output_dir", default=None)
    parser.add_argument("--custom_shot_and_epoch_and_general_copy", default=None, type=str)
    parser.add_argument("--resume", action="store_true", default=False)
    parser.add_argument("--extract_query", action="store_true", default=False)
    parser.add_argument(
        "--task_config",
        default="",
        metavar="FILE",
        help="path to task config file",
        type=str,
    )
    parser.add_argument(
        "--additional_model_config",
        default="",
        metavar="FILE",
        help="path to additional model config file",
        type=str,
    )
    args = parser.parse_args()
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        import datetime
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
            timeout=datetime.timedelta(0, 7200)
        )

    if args.disable_output_distributed:
        setup_for_distributed(args.local_rank <= 0)

    cfg.local_rank = args.local_rank
    cfg.num_gpus = num_gpus

    cfg.merge_from_file(args.config_file)
    if args.task_config:
        cfg.merge_from_file(args.task_config)
    if args.additional_model_config:
        cfg.merge_from_file(args.additional_model_config)
    cfg.merge_from_list(args.opts)

    # Specify the output dir for models
    if args.override_output_dir:
        cfg.OUTPUT_DIR = args.override_output_dir
    tuning_highlevel_override(cfg)
    cfg.freeze()

    # Seed every rank differently but deterministically
    seed = cfg.SOLVER.SEED + args.local_rank
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info(args)
    logger.info("Using {} GPUs".format(num_gpus))

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
    logger.info("Saving config into: {}".format(output_config_path))
    # Save the overloaded model config in the output directory
    if args.save_original_config:
        import shutil
        shutil.copy(args.config_file, os.path.join(cfg.OUTPUT_DIR, 'config_original.yml'))
    save_config(cfg, output_config_path)

    if args.extract_query:
        extract_query(cfg)
    else:
        # if cfg.DATASETS.FEW_SHOT and cfg.VISION_QUERY.ENABLED:
        #     max_query_number = int(os.path.basename(cfg.VISION_QUERY.QUERY_BANK_PATH).split('_')[-2])
        #     assert cfg.DATASETS.FEW_SHOT == max_query_number, 'You should first extract corresponding few-shot query instances.'
        #     assert max_query_number >= cfg.VISION_QUERY.NUM_QUERY_PER_CLASS
        model = train(
            cfg=cfg,
            local_rank=args.local_rank,
            distributed=args.distributed,
            use_tensorboard=args.use_tensorboard,
            resume=args.resume,
        )


if __name__ == "__main__":
    main()
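
# Example multi-GPU launch (script and config paths are illustrative; exact
# launcher flags depend on your PyTorch version):
#   python -m torch.distributed.launch --nproc_per_node=4 train_net.py \
#       --config-file configs/example.yaml --use-tensorboard --resume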