adapt_input_conv: add type hints

This commit is contained in:
Adam J. Stewart 2025-02-21 19:24:05 +01:00 committed by Ross Wightman
parent 105a667baa
commit c68d724e9c

View File

@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn as nn from torch import nn as nn
from torch import Tensor
from timm.layers import use_reentrant_ckpt from timm.layers import use_reentrant_ckpt
@ -284,7 +285,7 @@ def checkpoint_seq(
return x return x
def adapt_input_conv(in_chans, conv_weight): def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor:
conv_type = conv_weight.dtype conv_type = conv_weight.dtype
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
O, I, J, K = conv_weight.shape O, I, J, K = conv_weight.shape