diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 53f8e855..55ab04bf 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -2,7 +2,6 @@ import hashlib import json import logging import os -import sys from functools import partial from pathlib import Path from tempfile import TemporaryDirectory @@ -22,9 +21,9 @@ try: except ImportError: _has_safetensors = False -if sys.version_info >= (3, 8): +try: from typing import Literal -else: +except ImportError: from typing_extensions import Literal from timm import __version__ diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 666c18ee..380e3a64 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -27,7 +27,11 @@ import logging import math from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, Literal, List +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal import torch import torch.nn as nn