torchvision nms bug fix
parent
66676eb039
commit
124f0e8212
|
@ -8,6 +8,7 @@ import torch
|
|||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -151,7 +152,6 @@ def model_info(model, verbose=False):
|
|||
|
||||
def load_classifier(name='resnet101', n=2):
|
||||
# Loads a pretrained model reshaped to n-class output
|
||||
import torchvision
|
||||
model = torchvision.models.__dict__[name](pretrained=True)
|
||||
|
||||
# ResNet model properties
|
||||
|
|
Loading…
Reference in New Issue