mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
adapt_input_conv: add type hints
This commit is contained in:
parent
105a667baa
commit
c68d724e9c
@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from timm.layers import use_reentrant_ckpt
|
||||
|
||||
@ -284,7 +285,7 @@ def checkpoint_seq(
|
||||
return x
|
||||
|
||||
|
||||
def adapt_input_conv(in_chans, conv_weight):
|
||||
def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor:
|
||||
conv_type = conv_weight.dtype
|
||||
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
|
||||
O, I, J, K = conv_weight.shape
|
||||
|
Loading…
x
Reference in New Issue
Block a user