mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
support mambaout, metaformer, nest, nextvit, pvt_v2
This commit is contained in:
parent
2d3155908c
commit
6b6beffa6b
@ -6,7 +6,7 @@ MetaFormer (https://github.com/sail-sg/metaformer),
|
|||||||
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
|
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
|
||||||
"""
|
"""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -14,6 +14,7 @@ from torch import nn
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
|
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
@ -417,6 +418,68 @@ class MambaOut(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
|
||||||
|
channel_first = output_fmt == 'NCHW'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if channel_first:
|
||||||
|
# reshape to BCHW output format
|
||||||
|
intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
x = self.stages(x)
|
x = self.stages(x)
|
||||||
|
@ -28,7 +28,7 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
|
|||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
|
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
|
||||||
use_fused_attn
|
use_fused_attn
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
|
|
||||||
@ -597,6 +598,62 @@ class MetaFormer(nn.Module):
|
|||||||
final = nn.Identity()
|
final = nn.Identity()
|
||||||
self.head.fc = final
|
self.head.fc = final
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_head(self, x: Tensor, pre_logits: bool = False):
|
def forward_head(self, x: Tensor, pre_logits: bool = False):
|
||||||
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
|
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
|
||||||
x = self.head.global_pool(x)
|
x = self.head.global_pool(x)
|
||||||
|
@ -19,6 +19,7 @@ import collections.abc
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -28,6 +29,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
|
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
|
||||||
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
|
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_function
|
from ._features_fx import register_notrace_function
|
||||||
from ._manipulate import checkpoint_seq, named_apply
|
from ._manipulate import checkpoint_seq, named_apply
|
||||||
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
||||||
@ -420,6 +422,67 @@ class Nest(nn.Module):
|
|||||||
self.global_pool, self.head = create_classifier(
|
self.global_pool, self.head = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.levels), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.levels
|
||||||
|
else:
|
||||||
|
stages = self.levels[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
# Layer norm done over channel dim only (to NHWC and back)
|
||||||
|
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.levels), indices)
|
||||||
|
self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
x = self.levels(x)
|
x = self.levels(x)
|
||||||
|
@ -6,7 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
|
|||||||
"""
|
"""
|
||||||
# Copyright (c) ByteDance Inc. All rights reserved.
|
# Copyright (c) ByteDance Inc. All rights reserved.
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||||||
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
|
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
|
||||||
from timm.layers import ClassifierHead
|
from timm.layers import ClassifierHead
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._features_fx import register_notrace_function
|
from ._features_fx import register_notrace_function
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model
|
from ._registry import generate_default_cfgs, register_model
|
||||||
@ -560,6 +561,66 @@ class NextViT(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, pool_type=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.stem(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_norm:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -25,6 +25,7 @@ import torch.nn.functional as F
|
|||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
|
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import checkpoint
|
from ._manipulate import checkpoint
|
||||||
from ._registry import register_model, generate_default_cfgs
|
from ._registry import register_model, generate_default_cfgs
|
||||||
|
|
||||||
@ -386,6 +387,62 @@ class PyramidVisionTransformerV2(nn.Module):
|
|||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward_intermediates(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
indices: Optional[Union[int, List[int]]] = None,
|
||||||
|
norm: bool = False,
|
||||||
|
stop_early: bool = False,
|
||||||
|
output_fmt: str = 'NCHW',
|
||||||
|
intermediates_only: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||||
|
""" Forward features that returns intermediates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input image tensor
|
||||||
|
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||||
|
norm: Apply norm layer to compatible intermediates
|
||||||
|
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||||
|
output_fmt: Shape of intermediate feature outputs
|
||||||
|
intermediates_only: Only return intermediate features
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
|
||||||
|
intermediates = []
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||||
|
stages = self.stages
|
||||||
|
else:
|
||||||
|
stages = self.stages[:max_index + 1]
|
||||||
|
|
||||||
|
for feat_idx, stage in enumerate(stages):
|
||||||
|
x = stage(x)
|
||||||
|
if feat_idx in take_indices:
|
||||||
|
intermediates.append(x)
|
||||||
|
|
||||||
|
if intermediates_only:
|
||||||
|
return intermediates
|
||||||
|
|
||||||
|
return x, intermediates
|
||||||
|
|
||||||
|
def prune_intermediate_layers(
|
||||||
|
self,
|
||||||
|
indices: Union[int, List[int]] = 1,
|
||||||
|
prune_norm: bool = False,
|
||||||
|
prune_head: bool = True,
|
||||||
|
):
|
||||||
|
""" Prune layers not required for specified intermediates.
|
||||||
|
"""
|
||||||
|
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||||
|
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||||
|
if prune_head:
|
||||||
|
self.reset_classifier(0, '')
|
||||||
|
return take_indices
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
x = self.stages(x)
|
x = self.stages(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user