debug nextvit training (tbc)
parent d683211ae3
commit 36d54aaad2
@@ -2,6 +2,21 @@
 train:
   dataset_path: ImageNet:split=TRAIN
   batch_size_per_gpu: 64
+compute_precision:
+  teacher:
+    backbone:
+      sharding_strategy: NO_SHARD # SHARD_GRAD_OP
+    dino_head:
+      sharding_strategy: NO_SHARD
+    ibot_head:
+      sharding_strategy: NO_SHARD
+  student:
+    backbone:
+      sharding_strategy: NO_SHARD
+    dino_head:
+      sharding_strategy: NO_SHARD
+    ibot_head:
+      sharding_strategy: NO_SHARD
 student:
-  arch: vit_large # vit_large, nextvit
+  arch: nextvit # vit_large, nextvit
   block_chunks: 4
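Not part of the commit, just context for the override above: a minimal sketch of how such a file is layered on top of the base config, assuming the YAML is consumed with OmegaConf as in stock DINOv2 (the base values below are abbreviated).

# Minimal sketch: merging the override onto a base config with OmegaConf
# (assumption: the training code loads these YAML files via OmegaConf, as DINOv2 does).
from omegaconf import OmegaConf

base = OmegaConf.create({
    "train": {"dataset_path": "ImageNet:split=TRAIN", "batch_size_per_gpu": 64},
    "student": {"arch": "vit_large", "block_chunks": 4},
})
override = OmegaConf.create("""
compute_precision:
  teacher:
    backbone:
      sharding_strategy: NO_SHARD
student:
  arch: nextvit
""")
cfg = OmegaConf.merge(base, override)
print(cfg.student.arch)                                           # nextvit
print(cfg.compute_precision.teacher.backbone.sharding_strategy)   # NO_SHARD
print(cfg.train.batch_size_per_gpu)                               # 64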
@@ -10,27 +10,28 @@ class NextVitSmall(torch.nn.Module):
     def __init__(self, num_classes=197*1024) -> None:
         super().__init__()
 
-        # define backbone
-        self.backbone = _get_nextvit(
-            model_size="small",
-            frozen_stages=-1,
-            norm_eval=False,
-            with_extra_norm=True,
-            norm_cfg=dict(type="SyncBN", requires_grad=True),
-            in_channels=3,
-        )
-
         # self.proj_head = torch.nn.Sequential(
         #     torch.nn.Linear(1024, num_classes),
+        # # define backbone
+        # self.backbone = _get_nextvit(
+        #     model_size="small",
+        #     frozen_stages=-1,
+        #     norm_eval=False,
+        #     with_extra_norm=True,
+        #     norm_cfg=dict(type="SyncBN", requires_grad=True),
+        #     in_channels=3,
+        # )
 
+        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=1024, kernel_size=32, stride=32, padding=0)
         assert num_classes == 197 * 1024
         self.num_register_tokens = 1
         self.embed_dim = 1024
         self.blocks = [0] * 6
         self.proj_head = torch.nn.Linear(1024, num_classes)
 
     def forward_backbone(self, x, masks=None):
-        y = self.backbone(x)
-        y = functional.adaptive_avg_pool2d(y[-1], (1, 1))
+        # y = self.backbone(x)  # use y[-1]
+        y = self.conv(x)
+        print(x.shape, y.shape)
+        y = functional.adaptive_avg_pool2d(y, (1, 1))
         y = torch.flatten(y, 1)
         y = self.proj_head(y)
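For reference, a standalone shape walk-through of the conv stub above, assuming 224x224 inputs (the resolution is not stated in the diff): the stride-32 conv yields a 7x7 grid, global average pooling collapses it to a single 1024-d vector, and the linear head expands it to 197*1024 values, mimicking 197 ViT-L tokens of width 1024.

# Shape check of the debug stub (assumes 2 images at 224x224; not part of the commit).
import torch
import torch.nn.functional as functional

x = torch.randn(2, 3, 224, 224)
conv = torch.nn.Conv2d(in_channels=3, out_channels=1024, kernel_size=32, stride=32, padding=0)
proj_head = torch.nn.Linear(1024, 197 * 1024)

y = conv(x)                                    # (2, 1024, 7, 7)
y = functional.adaptive_avg_pool2d(y, (1, 1))  # (2, 1024, 1, 1)
y = torch.flatten(y, 1)                        # (2, 1024)
y = proj_head(y)                               # (2, 197 * 1024)
print(y.shape)  # torch.Size([2, 201728])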
@@ -396,8 +396,8 @@ class SSLMetaArch(nn.Module):
 
     def prepare_for_distributed_training(self):
         logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
-        if has_batchnorms(self.student):
-            raise NotImplementedError
+        # if has_batchnorms(self.student):
+        #     raise NotImplementedError
         # below will synchronize all student subnetworks across gpus:
         for k, v in self.student.items():
             self.teacher[k].load_state_dict(self.student[k].state_dict())
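The bypassed guard rejects student networks that contain BatchNorm; the norm_cfg=dict(type="SyncBN", ...) used for the NextViT backbone suggests it would trip that check, which is presumably why it is commented out alongside the NO_SHARD config above. An illustrative sketch of such a check (the real has_batchnorms helper in DINOv2 may be implemented differently):

# Illustrative BatchNorm detector, similar in spirit to has_batchnorms
# (assumption: not the actual DINOv2 helper).
from torch import nn

def contains_batchnorms(model: nn.Module) -> bool:
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(m, bn_types) for m in model.modules())

print(contains_batchnorms(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))))  # True
print(contains_batchnorms(nn.Linear(8, 8)))                                       # False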
@@ -22,10 +22,6 @@ from dinov2.utils.utils import CosineScheduler
 
 from dinov2.train.ssl_meta_arch import SSLMetaArch
 
-import sys
-sys.path.append('/home/li.yu/code/JupiterCVML/europa/base/src/europa')
-from dl.network.nextvit_brt import _get_nextvit
-
 
 torch.backends.cuda.matmul.allow_tf32 = True  # PyTorch 1.12 sets this to False by default
 logger = logging.getLogger("dinov2")
@@ -44,9 +40,9 @@ def get_args_parser(add_help: bool = True):
     parser.add_argument(
         "opts",
         help="""
 Modify config options at the end of the command. For Yacs configs, use
 space-separated "PATH.KEY VALUE" pairs.
 For python-based LazyConfig, use "path.key=value".
         """.strip(),
         default=None,
         nargs=argparse.REMAINDER,
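As an aside, a minimal sketch of how the REMAINDER-style "opts" positional collects the trailing overrides after the named flags (the invocation and the --config-file flag below are illustrative, mirroring the parser above):

# Sketch of the "opts" catch-all (illustrative invocation, not part of the commit).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--config-file", default="", type=str)  # assumed companion flag
parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)

args = parser.parse_args(["--config-file", "nextvit.yaml", "train.batch_size_per_gpu=32"])
print(args.opts)  # ['train.batch_size_per_gpu=32']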
@@ -307,18 +303,6 @@ def main(args):
     model = SSLMetaArch(cfg).to(torch.device("cuda"))
     model.prepare_for_distributed_training()
 
-    # model = _get_nextvit(
-    #     model_size="small",
-    #     frozen_stages=-1,
-    #     norm_eval=False,
-    #     with_extra_norm=True,
-    #     norm_cfg=dict(type="SyncBN", requires_grad=True),
-    #     in_channels=3,
-    # )
-    # print('tunable parameters', sum(p.numel() for p in model.parameters() if p.requires_grad))
-    # if args.distributed:
-    #     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
-
     logger.info("Model:\n{}".format(model))
     if args.eval_only:
         iteration = (