Attach transforms to model (#9028)
* Attach transforms to model Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update val.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update train.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9022/head^2
parent
4bc5520e94
commit
840b7232db
|
@ -122,16 +122,16 @@ def train(opt, device):
|
|||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
model = model.to(device)
|
||||
names = trainloader.dataset.classes # class names
|
||||
model.names = names # attach class names
|
||||
|
||||
# Info
|
||||
if RANK in {-1, 0}:
|
||||
model.names = trainloader.dataset.classes # attach class names
|
||||
model.transforms = testloader.dataset.torch_transforms # attach inference transforms
|
||||
model_info(model)
|
||||
if opt.verbose:
|
||||
LOGGER.info(model)
|
||||
images, labels = next(iter(trainloader))
|
||||
file = imshow_cls(images[:25], labels[:25], names=names, f=save_dir / 'train_images.jpg')
|
||||
file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
|
||||
logger.log_images(file, name='Train Examples')
|
||||
logger.log_graph(model, imgsz) # log model
|
||||
|
||||
|
@ -254,8 +254,8 @@ def train(opt, device):
|
|||
|
||||
# Plot examples
|
||||
images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels
|
||||
pred = torch.max(ema.ema((images.half() if cuda else images.float()).to(device)), 1)[1]
|
||||
file = imshow_cls(images, labels, pred, names, verbose=False, f=save_dir / 'test_images.jpg')
|
||||
pred = torch.max(ema.ema(images.to(device)), 1)[1]
|
||||
file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')
|
||||
|
||||
# Log results
|
||||
meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()}
|
||||
|
|
|
@ -39,7 +39,7 @@ def run(
|
|||
project=ROOT / 'runs/val-cls', # save to project/name
|
||||
name='exp', # save to project/name
|
||||
exist_ok=False, # existing project/name ok, do not increment
|
||||
half=True, # use FP16 half-precision inference
|
||||
half=False, # use FP16 half-precision inference
|
||||
dnn=False, # use OpenCV DNN for ONNX inference
|
||||
model=None,
|
||||
dataloader=None,
|
||||
|
@ -124,7 +124,6 @@ def run(
|
|||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
||||
|
||||
model.float() # for training
|
||||
return top1, top5, loss
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue