mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
adapt_input_conv: add type hints
This commit is contained in:
parent
105a667baa
commit
c68d724e9c
@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from timm.layers import use_reentrant_ckpt
|
from timm.layers import use_reentrant_ckpt
|
||||||
|
|
||||||
@ -284,7 +285,7 @@ def checkpoint_seq(
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def adapt_input_conv(in_chans, conv_weight):
|
def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor:
|
||||||
conv_type = conv_weight.dtype
|
conv_type = conv_weight.dtype
|
||||||
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
|
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
|
||||||
O, I, J, K = conv_weight.shape
|
O, I, J, K = conv_weight.shape
|
||||||
|
Loading…
x
Reference in New Issue
Block a user