Update type hint for register_notrace_module

register_notrace_module is used to decorate types (i.e. subclasses of nn.Module).
It is not called on module instances.
This commit is contained in:
Jasha10 2022-07-22 16:59:55 -05:00 committed by GitHub
parent d7b55a9429
commit 56c3a84db3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union
from typing import Callable, List, Dict, Union, Type
import torch
from torch import nn
@ -35,7 +35,7 @@ except ImportError:
pass
def register_notrace_module(module: nn.Module):
def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""