mirror of
https://github.com/facebookresearch/moco-v3.git
synced 2025-06-03 14:59:22 +08:00
add note, and temporarily add tensorboard logging
This commit is contained in:
parent
519aa28fb8
commit
84c7fc0fed
37
main_moco.py
37
main_moco.py
@ -14,9 +14,10 @@ import shutil
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# ===== SLURM, to delete =====
|
# ===== to delete =====
|
||||||
import signal
|
import signal
|
||||||
# ===== SLURM, to delete =====
|
from tensorboardX import SummaryWriter
|
||||||
|
# ===== to delete =====
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -38,7 +39,7 @@ import moco.builder
|
|||||||
import moco.loader
|
import moco.loader
|
||||||
import moco.optimizer
|
import moco.optimizer
|
||||||
|
|
||||||
# ===== SLURM, to delete =====
|
# ===== to delete =====
|
||||||
def signalHandler(a, b):
|
def signalHandler(a, b):
|
||||||
if a == signal.SIGUSR1:
|
if a == signal.SIGUSR1:
|
||||||
logger.info('Got SIGUSR1.')
|
logger.info('Got SIGUSR1.')
|
||||||
@ -47,7 +48,7 @@ def signalHandler(a, b):
|
|||||||
|
|
||||||
signal.signal(signal.SIGUSR1, signalHandler)
|
signal.signal(signal.SIGUSR1, signalHandler)
|
||||||
signal.signal(signal.SIGTERM, signalHandler)
|
signal.signal(signal.SIGTERM, signalHandler)
|
||||||
# ===== SLURM, to delete =====
|
# ===== to delete =====
|
||||||
|
|
||||||
torchvision_model_names = sorted(name for name in torchvision_models.__dict__
|
torchvision_model_names = sorted(name for name in torchvision_models.__dict__
|
||||||
if name.islower() and not name.startswith("__")
|
if name.islower() and not name.startswith("__")
|
||||||
@ -126,10 +127,10 @@ parser.add_argument('--optimizer', default='lars', type=str,
|
|||||||
parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
|
parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
|
||||||
help='number of warmup epochs')
|
help='number of warmup epochs')
|
||||||
|
|
||||||
# ===== OTHERS, to delete =====
|
# ===== to delete =====
|
||||||
parser.add_argument('--checkpoint-folder', default='.', type=str, metavar='PATH',
|
parser.add_argument('--checkpoint-folder', default='.', type=str, metavar='PATH',
|
||||||
help='path to save the checkpoints (default: .)')
|
help='path to save the checkpoints (default: .)')
|
||||||
# ===== OTHERS, to delete =====
|
# ===== to delete =====
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -153,10 +154,10 @@ def main():
|
|||||||
|
|
||||||
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
||||||
|
|
||||||
# ===== PATH, to delete =====
|
# ===== to delete =====
|
||||||
if not os.path.exists(args.checkpoint_folder):
|
if not os.path.exists(args.checkpoint_folder):
|
||||||
os.makedirs(args.checkpoint_folder)
|
os.makedirs(args.checkpoint_folder)
|
||||||
# ===== PATH, to delete =====
|
# ===== to delete =====
|
||||||
|
|
||||||
ngpus_per_node = torch.cuda.device_count()
|
ngpus_per_node = torch.cuda.device_count()
|
||||||
if args.multiprocessing_distributed:
|
if args.multiprocessing_distributed:
|
||||||
@ -250,6 +251,10 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
weight_decay=args.weight_decay)
|
weight_decay=args.weight_decay)
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
# ===== to delete =====
|
||||||
|
if args.rank == 0:
|
||||||
|
summary_writer = SummaryWriter(logdir=args.checkpoint_folder)
|
||||||
|
# ===== to delete =====
|
||||||
|
|
||||||
# optionally resume from a checkpoint
|
# optionally resume from a checkpoint
|
||||||
if args.resume:
|
if args.resume:
|
||||||
@ -324,7 +329,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
adjust_learning_rate(optimizer, init_lr, epoch, args)
|
adjust_learning_rate(optimizer, init_lr, epoch, args)
|
||||||
|
|
||||||
# train for one epoch
|
# train for one epoch
|
||||||
train(train_loader, model, optimizer, scaler, epoch, args)
|
train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)
|
||||||
|
|
||||||
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
||||||
and args.rank == 0): # only the first GPU saves checkpoint
|
and args.rank == 0): # only the first GPU saves checkpoint
|
||||||
@ -336,14 +341,18 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
'scaler': scaler.state_dict(),
|
'scaler': scaler.state_dict(),
|
||||||
}, is_best=False, filename='%s/checkpoint_%04d.pth.tar' % (args.checkpoint_folder, epoch))
|
}, is_best=False, filename='%s/checkpoint_%04d.pth.tar' % (args.checkpoint_folder, epoch))
|
||||||
|
|
||||||
|
# ===== to delete =====
|
||||||
|
if args.rank == 0:
|
||||||
|
summary_writer.close()
|
||||||
|
# ===== to delete =====
|
||||||
|
|
||||||
def train(train_loader, model, optimizer, scaler, epoch, args):
|
def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args):
|
||||||
batch_time = AverageMeter('Time', ':6.3f')
|
batch_time = AverageMeter('Time', ':6.3f')
|
||||||
data_time = AverageMeter('Data', ':6.3f')
|
data_time = AverageMeter('Data', ':6.3f')
|
||||||
losses = AverageMeter('Loss', ':.4e')
|
losses = AverageMeter('Loss', ':.4e')
|
||||||
progress = ProgressMeter(
|
progress = ProgressMeter(
|
||||||
len(train_loader),
|
len(train_loader),
|
||||||
[batch_time, data_time, losses, top1, top5],
|
[batch_time, data_time, losses],
|
||||||
prefix="Epoch: [{}]".format(epoch))
|
prefix="Epoch: [{}]".format(epoch))
|
||||||
|
|
||||||
# switch to train mode
|
# switch to train mode
|
||||||
@ -364,7 +373,11 @@ def train(train_loader, model, optimizer, scaler, epoch, args):
|
|||||||
loss = model(images[0], images[1], moco_m)
|
loss = model(images[0], images[1], moco_m)
|
||||||
|
|
||||||
losses.update(loss.item(), images[0].size(0))
|
losses.update(loss.item(), images[0].size(0))
|
||||||
|
# ===== to delete =====
|
||||||
|
if args.rank == 0:
|
||||||
|
summary_writer.add_scalar("loss", loss.item(), epoch * len(train_loader) + i)
|
||||||
|
# ===== to delete =====
|
||||||
|
|
||||||
# compute gradient and do SGD step
|
# compute gradient and do SGD step
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
@ -23,7 +23,6 @@ class MoCo(nn.Module):
|
|||||||
super(MoCo, self).__init__()
|
super(MoCo, self).__init__()
|
||||||
|
|
||||||
self.T = T
|
self.T = T
|
||||||
self.criterion =
|
|
||||||
|
|
||||||
if with_vit:
|
if with_vit:
|
||||||
self._init_encoders_with_vit(base_encoder, dim, mlp_dim)
|
self._init_encoders_with_vit(base_encoder, dim, mlp_dim)
|
||||||
@ -46,6 +45,8 @@ class MoCo(nn.Module):
|
|||||||
mlp.append(nn.BatchNorm1d(dim2))
|
mlp.append(nn.BatchNorm1d(dim2))
|
||||||
mlp.append(nn.ReLU(inplace=True))
|
mlp.append(nn.ReLU(inplace=True))
|
||||||
else:
|
else:
|
||||||
|
# similar to SimCLR: https://github.com/google-research/simclr/blob/master/model_util.py#L157
|
||||||
|
# remove this last BN also works
|
||||||
mlp.append(nn.BatchNorm1d(dim2, affine=False))
|
mlp.append(nn.BatchNorm1d(dim2, affine=False))
|
||||||
|
|
||||||
return nn.Sequential(*mlp)
|
return nn.Sequential(*mlp)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user