mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Update
This commit is contained in:
parent
35d12e6f6d
commit
8e71c9daaa
@ -19,6 +19,7 @@ import math
|
||||
import os
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@ -39,7 +40,11 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
from models.common import Classify, DetectMultiBackend
|
||||
from utils.general import NUM_THREADS, download, check_file, increment_path, check_git_status, check_requirements, \
|
||||
colorstr
|
||||
from utils.torch_utils import model_info, select_device, is_parallel
|
||||
from utils.torch_utils import model_info, select_device, de_parallel
|
||||
|
||||
# Functions
|
||||
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
|
||||
denormalize = lambda x, mean=0.5, std=0.25: x * std + mean
|
||||
|
||||
|
||||
def train():
|
||||
@ -167,8 +172,9 @@ def train():
|
||||
if (not opt.nosave) or final_epoch:
|
||||
ckpt = {'epoch': epoch,
|
||||
'best_fitness': best_fitness,
|
||||
'model': deepcopy(model.module if is_parallel(model) else model).half(),
|
||||
'optimizer': None}
|
||||
'model': deepcopy(de_parallel(model)).half(),
|
||||
'optimizer': None, # optimizer.state_dict()
|
||||
'date': datetime.now().isoformat()}
|
||||
|
||||
# Save last, best and delete
|
||||
torch.save(ckpt, last)
|
||||
@ -223,7 +229,6 @@ def classify(model, size=128, file='../datasets/mnist/test/3/30.png', plot=False
|
||||
import torch.nn.functional as F
|
||||
|
||||
resize = torch.nn.Upsample(size=(size, size), mode='bilinear', align_corners=False) # image resize
|
||||
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
|
||||
|
||||
# Image
|
||||
im = cv2.imread(str(file))[..., ::-1] # HWC, BGR to RGB
|
||||
@ -301,11 +306,7 @@ if __name__ == '__main__':
|
||||
cuda = device.type != 'cpu'
|
||||
opt.hyp = check_file(opt.hyp) # check files
|
||||
opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
|
||||
|
||||
# Functions
|
||||
resize = torch.nn.Upsample(size=(opt.img_size, opt.img_size), mode='bilinear', align_corners=False) # image resize
|
||||
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
|
||||
denormalize = lambda x, mean=0.5, std=0.25: x * std + mean
|
||||
|
||||
# Train
|
||||
train()
|
||||
|
Loading…
x
Reference in New Issue
Block a user