This commit is contained in:
Glenn Jocher 2021-12-19 18:17:52 +01:00
parent 35d12e6f6d
commit 8e71c9daaa

View File

@ -19,6 +19,7 @@ import math
import os import os
import sys import sys
from copy import deepcopy from copy import deepcopy
from datetime import datetime
from pathlib import Path from pathlib import Path
import torch import torch
@ -39,7 +40,11 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import Classify, DetectMultiBackend from models.common import Classify, DetectMultiBackend
from utils.general import NUM_THREADS, download, check_file, increment_path, check_git_status, check_requirements, \ from utils.general import NUM_THREADS, download, check_file, increment_path, check_git_status, check_requirements, \
colorstr colorstr
from utils.torch_utils import model_info, select_device, is_parallel from utils.torch_utils import model_info, select_device, de_parallel
# Functions
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
denormalize = lambda x, mean=0.5, std=0.25: x * std + mean
def train(): def train():
@ -167,8 +172,9 @@ def train():
if (not opt.nosave) or final_epoch: if (not opt.nosave) or final_epoch:
ckpt = {'epoch': epoch, ckpt = {'epoch': epoch,
'best_fitness': best_fitness, 'best_fitness': best_fitness,
'model': deepcopy(model.module if is_parallel(model) else model).half(), 'model': deepcopy(de_parallel(model)).half(),
'optimizer': None} 'optimizer': None, # optimizer.state_dict()
'date': datetime.now().isoformat()}
# Save last, best and delete # Save last, best and delete
torch.save(ckpt, last) torch.save(ckpt, last)
@ -223,7 +229,6 @@ def classify(model, size=128, file='../datasets/mnist/test/3/30.png', plot=False
import torch.nn.functional as F import torch.nn.functional as F
resize = torch.nn.Upsample(size=(size, size), mode='bilinear', align_corners=False) # image resize resize = torch.nn.Upsample(size=(size, size), mode='bilinear', align_corners=False) # image resize
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
# Image # Image
im = cv2.imread(str(file))[..., ::-1] # HWC, BGR to RGB im = cv2.imread(str(file))[..., ::-1] # HWC, BGR to RGB
@ -301,11 +306,7 @@ if __name__ == '__main__':
cuda = device.type != 'cpu' cuda = device.type != 'cpu'
opt.hyp = check_file(opt.hyp) # check files opt.hyp = check_file(opt.hyp) # check files
opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
# Functions
resize = torch.nn.Upsample(size=(opt.img_size, opt.img_size), mode='bilinear', align_corners=False) # image resize resize = torch.nn.Upsample(size=(opt.img_size, opt.img_size), mode='bilinear', align_corners=False) # image resize
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
denormalize = lambda x, mean=0.5, std=0.25: x * std + mean
# Train # Train
train() train()