Attach transforms to model (#9028)

* Attach transforms to model

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update val.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update train.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/9022/head^2
Glenn Jocher 2022-08-19 01:59:51 +02:00 committed by GitHub
parent 4bc5520e94
commit 840b7232db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 7 deletions

View File

@ -122,16 +122,16 @@ def train(opt, device):
for p in model.parameters():
p.requires_grad = True # for training
model = model.to(device)
names = trainloader.dataset.classes # class names
model.names = names # attach class names
# Info
if RANK in {-1, 0}:
model.names = trainloader.dataset.classes # attach class names
model.transforms = testloader.dataset.torch_transforms # attach inference transforms
model_info(model)
if opt.verbose:
LOGGER.info(model)
images, labels = next(iter(trainloader))
file = imshow_cls(images[:25], labels[:25], names=names, f=save_dir / 'train_images.jpg')
file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
logger.log_images(file, name='Train Examples')
logger.log_graph(model, imgsz) # log model
@ -254,8 +254,8 @@ def train(opt, device):
# Plot examples
images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels
pred = torch.max(ema.ema((images.half() if cuda else images.float()).to(device)), 1)[1]
file = imshow_cls(images, labels, pred, names, verbose=False, f=save_dir / 'test_images.jpg')
pred = torch.max(ema.ema(images.to(device)), 1)[1]
file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')
# Log results
meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()}

View File

@ -39,7 +39,7 @@ def run(
project=ROOT / 'runs/val-cls', # save to project/name
name='exp', # save to project/name
exist_ok=False, # existing project/name ok, do not increment
half=True, # use FP16 half-precision inference
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
model=None,
dataloader=None,
@ -124,7 +124,6 @@ def run(
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
model.float() # for training
return top1, top5, loss