adapt_input_conv: add type hints

This commit is contained in:
Adam J. Stewart 2025-02-21 19:24:05 +01:00 committed by Ross Wightman
parent 105a667baa
commit c68d724e9c

View File

@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
import torch
import torch.utils.checkpoint
from torch import nn as nn
from torch import Tensor
from timm.layers import use_reentrant_ckpt
@ -284,7 +285,7 @@ def checkpoint_seq(
return x
def adapt_input_conv(in_chans, conv_weight):
def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor:
conv_type = conv_weight.dtype
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
O, I, J, K = conv_weight.shape