debug nextvit training (tbc)

pull/488/head
liyubrt 2024-11-30 13:03:57 -08:00
parent d683211ae3
commit 36d54aaad2
4 changed files with 36 additions and 36 deletions

View File

@@ -2,6 +2,21 @@
train:
  dataset_path: ImageNet:split=TRAIN
  batch_size_per_gpu: 64
compute_precision:
  teacher:
    backbone:
      sharding_strategy: NO_SHARD # SHARD_GRAD_OP
    dino_head:
      sharding_strategy: NO_SHARD
    ibot_head:
      sharding_strategy: NO_SHARD
  student:
    backbone:
      sharding_strategy: NO_SHARD
    dino_head:
      sharding_strategy: NO_SHARD
    ibot_head:
      sharding_strategy: NO_SHARD
student:
  arch: vit_large # vit_large, nextvit
  arch: nextvit # vit_large, nextvit
  block_chunks: 4
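For reference (not part of the diff): the sharding_strategy strings above presumably select among the FSDP modes exposed by torch.distributed.fsdp. A minimal sketch of such a mapping, with the dictionary and helper names assumed:

# Sketch: map the config's sharding_strategy strings onto FSDP sharding modes (assumed wiring).
from torch.distributed.fsdp import ShardingStrategy

ASSUMED_SHARDING_STRATEGIES = {
    "NO_SHARD": ShardingStrategy.NO_SHARD,            # replicate everything, DDP-style
    "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,  # shard gradients and optimizer state
    "FULL_SHARD": ShardingStrategy.FULL_SHARD,        # additionally shard parameters
}

def resolve_sharding_strategy(name: str) -> ShardingStrategy:
    # Raises KeyError for strings the config does not recognize.
    return ASSUMED_SHARDING_STRATEGIES[name]

Switching every entry to NO_SHARD keeps whole copies of the weights on each GPU (DDP-like behavior), presumably to take parameter sharding out of the picture while the NextViT student is being debugged.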

View File

@@ -10,27 +10,28 @@ class NextVitSmall(torch.nn.Module):
    def __init__(self, num_classes=197*1024) -> None:
        super().__init__()
        # define backbone
        self.backbone = _get_nextvit(
            model_size="small",
            frozen_stages=-1,
            norm_eval=False,
            with_extra_norm=True,
            norm_cfg=dict(type="SyncBN", requires_grad=True),
            in_channels=3,
        )
        # self.proj_head = torch.nn.Sequential(
        #     torch.nn.Linear(1024, num_classes),
        # # define backbone
        # self.backbone = _get_nextvit(
        #     model_size="small",
        #     frozen_stages=-1,
        #     norm_eval=False,
        #     with_extra_norm=True,
        #     norm_cfg=dict(type="SyncBN", requires_grad=True),
        #     in_channels=3,
        # )
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=1024, kernel_size=32, stride=32, padding=0)
        assert num_classes == 197 * 1024
        self.num_register_tokens = 1
        self.embed_dim = 1024
        self.blocks = [0] * 6
        self.proj_head = torch.nn.Linear(1024, num_classes)

    def forward_backbone(self, x, masks=None):
        y = self.backbone(x)
        y = functional.adaptive_avg_pool2d(y[-1], (1, 1))
        # y = self.backbone(x)  # use y[-1]
        y = self.conv(x)
        print(x.shape, y.shape)
        y = functional.adaptive_avg_pool2d(y, (1, 1))
        y = torch.flatten(y, 1)
        y = self.proj_head(y)
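The debug path above replaces the NextViT backbone with a single stride-32 Conv2d followed by global pooling and a linear head; 197 * 1024 matches the 196 patch tokens plus one class token of a ViT-L/16 at 224x224 with embed_dim 1024. A standalone shape check of that path, assuming 224x224 inputs (not part of the commit):

import torch
from torch.nn import functional

# Sketch: shape check of the Conv2d -> pool -> flatten -> proj_head debug path above.
x = torch.randn(2, 3, 224, 224)                # assumed 224x224 crops
conv = torch.nn.Conv2d(in_channels=3, out_channels=1024, kernel_size=32, stride=32, padding=0)
proj_head = torch.nn.Linear(1024, 197 * 1024)

y = conv(x)                                    # (2, 1024, 7, 7)
y = functional.adaptive_avg_pool2d(y, (1, 1))  # (2, 1024, 1, 1)
y = torch.flatten(y, 1)                        # (2, 1024)
y = proj_head(y)                               # (2, 197 * 1024) == (2, 201728)
assert y.shape == (2, 197 * 1024)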

View File

@@ -396,8 +396,8 @@ class SSLMetaArch(nn.Module):
    def prepare_for_distributed_training(self):
        logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
        if has_batchnorms(self.student):
            raise NotImplementedError
        # if has_batchnorms(self.student):
        #     raise NotImplementedError
        # below will synchronize all student subnetworks across gpus:
        for k, v in self.student.items():
            self.teacher[k].load_state_dict(self.student[k].state_dict())
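The bypassed guard matters here because the NextViT backbone is built with norm_cfg=dict(type="SyncBN", requires_grad=True), so batch-norm layers are present in the student. A rough sketch of what such a has_batchnorms check typically looks like (assumed, not copied from the repo):

import torch.nn as nn

# Sketch (assumed helper): report whether any (Sync)BatchNorm module exists in the model.
def has_batchnorms(model: nn.Module) -> bool:
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(m, bn_types) for m in model.modules())

Commenting the check out lets the FSDP setup proceed with a batch-norm student, which upstream DINOv2 otherwise refuses via the NotImplementedError.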

View File

@@ -22,10 +22,6 @@ from dinov2.utils.utils import CosineScheduler
from dinov2.train.ssl_meta_arch import SSLMetaArch
import sys
sys.path.append('/home/li.yu/code/JupiterCVML/europa/base/src/europa')
from dl.network.nextvit_brt import _get_nextvit
torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default
logger = logging.getLogger("dinov2")
@@ -44,9 +40,9 @@ def get_args_parser(add_help: bool = True):
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        Modify config options at the end of the command. For Yacs configs, use
        space-separated "PATH.KEY VALUE" pairs.
        For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
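With nargs=argparse.REMAINDER, everything after the recognized flags lands in args.opts, so overrides such as the student.arch change in the first file can presumably also be passed on the command line. An illustrative parse (the --config-file flag and the yaml filename are assumed, not shown in this hunk):

import argparse

# Sketch: how REMAINDER collects trailing "path.key=value" overrides into args.opts.
parser = argparse.ArgumentParser()
parser.add_argument("--config-file", default="", metavar="FILE")  # assumed flag
parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)

args = parser.parse_args(
    ["--config-file", "nextvit_debug.yaml",  # hypothetical config name
     "student.arch=nextvit", "train.batch_size_per_gpu=64"]
)
print(args.opts)  # ['student.arch=nextvit', 'train.batch_size_per_gpu=64']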
@@ -307,18 +303,6 @@ def main(args):
    model = SSLMetaArch(cfg).to(torch.device("cuda"))
    model.prepare_for_distributed_training()
    # model = _get_nextvit(
    #     model_size="small",
    #     frozen_stages=-1,
    #     norm_eval=False,
    #     with_extra_norm=True,
    #     norm_cfg=dict(type="SyncBN", requires_grad=True),
    #     in_channels=3,
    # )
    # print('tunable parameters', sum(p.numel() for p in model.parameters() if p.requires_grad))
    # if args.distributed:
    #     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        iteration = (