mirror of https://github.com/WongKinYiu/yolov7.git
fix training with frozen layers (#378)
parent
1e51f564e0
commit
b8956dd5a5
7
train.py
7
train.py
|
@ -40,8 +40,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def train(hyp, opt, device, tb_writer=None):
|
def train(hyp, opt, device, tb_writer=None):
|
||||||
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
||||||
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \
|
||||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze
|
||||||
|
|
||||||
# Directories
|
# Directories
|
||||||
wdir = save_dir / 'weights'
|
wdir = save_dir / 'weights'
|
||||||
|
@ -99,7 +99,7 @@ def train(hyp, opt, device, tb_writer=None):
|
||||||
test_path = data_dict['val']
|
test_path = data_dict['val']
|
||||||
|
|
||||||
# Freeze
|
# Freeze
|
||||||
freeze = [] # parameter names to freeze (full or partial)
|
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # parameter names to freeze (full or partial)
|
||||||
for k, v in model.named_parameters():
|
for k, v in model.named_parameters():
|
||||||
v.requires_grad = True # train all layers
|
v.requires_grad = True # train all layers
|
||||||
if any(x in k for x in freeze):
|
if any(x in k for x in freeze):
|
||||||
|
@ -555,6 +555,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
|
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
|
||||||
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
||||||
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
||||||
|
parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
# Set DDP variables
|
# Set DDP variables
|
||||||
|
|
Loading…
Reference in New Issue