Increase FLOPS robustness (#1608)
parent
ba48f867ea
commit
8918e63476
|
@ -1,12 +1,12 @@
|
||||||
# PyTorch utils
|
# PyTorch utils
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -152,7 +152,7 @@ def model_info(model, verbose=False, img_size=640):
|
||||||
|
|
||||||
try: # FLOPS
|
try: # FLOPS
|
||||||
from thop import profile
|
from thop import profile
|
||||||
stride = int(model.stride.max())
|
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
|
||||||
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
|
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
|
||||||
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
|
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
|
||||||
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
||||||
|
|
Loading…
Reference in New Issue