Mirror of https://github.com/ultralytics/yolov5.git, synced 2025-06-03 14:49:29 +08:00
Update train.py (#3667)

parent ac34834563
commit 6d6e2ca65f
train.py | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
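This is a pure formatting change, with no behavioral effect: the `train()` signature is split to one argument per line, the `import test` comment is shortened, the wrapped two-line `wandb_logger.log_model(...)` call is joined onto a single line, and the `# end training` marker is extended to full ruler width.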
@@ -22,7 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
-import test  # import test.py to get mAP after each epoch
+import test  # for end-of-epoch mAP
 from models.experimental import attempt_load
 from models.yolo import Model
 from utils.autoanchor import check_anchors
@@ -39,7 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
 logger = logging.getLogger(__name__)
 
 
-def train(hyp, opt, device, tb_writer=None):
+def train(hyp,
+          opt,
+          device,
+          tb_writer=None
+          ):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -404,12 +408,11 @@ def train(hyp, opt, device, tb_writer=None):
                     torch.save(ckpt, best)
                 if wandb_logger.wandb:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
-                        wandb_logger.log_model(
-                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
+                        wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
 
         # end epoch ----------------------------------------------------------------------------------------------------
-    # end training
+    # end training -----------------------------------------------------------------------------------------------------
     if rank in [-1, 0]:
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
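As context for the third hunk, here is a minimal runnable sketch of the checkpoint-upload gate that surrounds the reformatted `log_model` call. The helper name `should_log_checkpoint` and the standalone framing are illustrative, not from the commit; the boolean expression itself is lifted verbatim from the diff.

# Sketch of the save-period gate around wandb_logger.log_model(...) above.
# Assumptions (illustrative, not from the diff): epochs are 0-indexed, and
# save_period == -1 disables periodic checkpoint uploads entirely.
def should_log_checkpoint(epoch: int, final_epoch: bool, save_period: int) -> bool:
    # Same expression as train.py, with opt.save_period passed explicitly.
    return ((epoch + 1) % save_period == 0 and not final_epoch) and save_period != -1


if __name__ == "__main__":
    # With save_period=3 over a 10-epoch run (final epoch 9), uploads
    # would fire after epochs 2, 5 and 8.
    for epoch in range(10):
        print(epoch, should_log_checkpoint(epoch, final_epoch=epoch == 9, save_period=3))

Joining the `log_model` call onto one line keeps this condition and the upload visually adjacent, a readability tweak consistent with the rest of the commit.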