This commit is contained in:
Glenn Jocher 2021-12-19 18:17:52 +01:00
parent 35d12e6f6d
commit 8e71c9daaa

View File

@ -19,6 +19,7 @@ import math
import os
import sys
from copy import deepcopy
from datetime import datetime
from pathlib import Path
import torch
@ -39,7 +40,11 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import Classify, DetectMultiBackend
from utils.general import NUM_THREADS, download, check_file, increment_path, check_git_status, check_requirements, \
colorstr
from utils.torch_utils import model_info, select_device, is_parallel
from utils.torch_utils import model_info, select_device, de_parallel
# Functions
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
denormalize = lambda x, mean=0.5, std=0.25: x * std + mean
def train():
@ -167,8 +172,9 @@ def train():
if (not opt.nosave) or final_epoch:
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'model': deepcopy(model.module if is_parallel(model) else model).half(),
'optimizer': None}
'model': deepcopy(de_parallel(model)).half(),
'optimizer': None, # optimizer.state_dict()
'date': datetime.now().isoformat()}
# Save last, best and delete
torch.save(ckpt, last)
@ -223,7 +229,6 @@ def classify(model, size=128, file='../datasets/mnist/test/3/30.png', plot=False
import torch.nn.functional as F
resize = torch.nn.Upsample(size=(size, size), mode='bilinear', align_corners=False) # image resize
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
# Image
im = cv2.imread(str(file))[..., ::-1] # HWC, BGR to RGB
@ -301,11 +306,7 @@ if __name__ == '__main__':
cuda = device.type != 'cpu'
opt.hyp = check_file(opt.hyp) # check files
opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
# Functions
resize = torch.nn.Upsample(size=(opt.img_size, opt.img_size), mode='bilinear', align_corners=False) # image resize
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
denormalize = lambda x, mean=0.5, std=0.25: x * std + mean
# Train
train()