Update type hint for register_notrace_module

register_notrace_module is used to decorate types (i.e. subclasses of nn.Module).
It is not called on module instances.
This commit is contained in:
Jasha10 2022-07-22 16:59:55 -05:00 committed by GitHub
parent d7b55a9429
commit 56c3a84db3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
""" PyTorch FX Based Feature Extraction Helpers """ PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html Using https://pytorch.org/vision/stable/feature_extraction.html
""" """
from typing import Callable, List, Dict, Union from typing import Callable, List, Dict, Union, Type
import torch import torch
from torch import nn from torch import nn
@ -35,7 +35,7 @@ except ImportError:
pass pass
def register_notrace_module(module: nn.Module): def register_notrace_module(module: Type[nn.Module]):
""" """
Any module not under timm.models.layers should get this decorator if we don't want to trace through it. Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
""" """