mirror of https://github.com/JDAI-CV/fast-reid.git

commit 91dc9bc71f (pull/43/head)

    Merge branch 'master' of github.com:L1aoXingyu/fast-reid

    Conflicts:
        fastreid/config/defaults.py
        fastreid/layers/gem_pool.py
        fastreid/modeling/backbones/resnet.py
        fastreid/modeling/heads/__init__.py
        fastreid/modeling/heads/build.py
        fastreid/modeling/losses/build.py
        fastreid/modeling/meta_arch/__init__.py
        fastreid/modeling/meta_arch/abd_network.py
        fastreid/modeling/meta_arch/baseline.py
        fastreid/modeling/meta_arch/bdb_network.py
        fastreid/modeling/meta_arch/mf_network.py
        projects/StrongBaseline/configs/Base-Strongbaseline.yml
        projects/StrongBaseline/configs/baseline_dukemtmc.yml
        projects/StrongBaseline/train_net.py
@@ -1,49 +0,0 @@
-# encoding: utf-8
-"""
-@author: liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import Parameter
-
-
-class AM_softmax(nn.Module):
-    r"""Implementation of the large margin cosine distance:
-    Args:
-        in_features: size of each input sample
-        out_features: size of each output sample
-        s: norm of the input feature
-        m: margin
-        cos(theta) - m
-    """
-
-    def __init__(self, in_features, out_features, s=30.0, m=0.40):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.s = s
-        self.m = m
-        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
-        # nn.init.normal_(self.weight, std=0.001)
-        nn.init.xavier_uniform_(self.weight)
-
-    def forward(self, input, label):
-        # --------------------------- cos(theta) & phi(theta) ---------------------------
-        cosine = F.linear(F.normalize(input), F.normalize(self.weight))  # (bs, num_classes)
-        phi = cosine - self.m
-        # phi = cosine
-        # --------------------------- convert label to one-hot ---------------------------
-        one_hot = torch.zeros(cosine.size()).to(label.device)
-        # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
-        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
-        # ------------- torch.where(out_i = x_i if condition_i else y_i) -------------
-        # you can use torch.where if your torch.__version__ is 0.4
-        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
-        # output *= torch.norm(input, p=2, dim=1, keepdim=True)
-        output *= self.s
-
-        return output
-
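For reference, the deleted head above implements the additive-margin softmax trick: the margin `m` is subtracted from the target-class cosine only, and the result is rescaled by `s` before cross-entropy. A minimal usage sketch (my addition, assuming the `AM_softmax` class above; the feature size and class count are illustrative):

```python
import torch
import torch.nn.functional as F

head = AM_softmax(in_features=2048, out_features=751)  # 751 ids, e.g. Market1501
features = torch.randn(16, 2048)           # a batch of backbone embeddings
labels = torch.randint(0, 751, (16,))      # ground-truth identity labels
logits = head(features, labels)            # (16, 751); target cosines get -m, all scaled by s
loss = F.cross_entropy(logits, labels)     # optimize with standard cross-entropy
```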
@@ -0,0 +1,45 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from torch import nn
+
+from .build import REID_HEADS_REGISTRY
+from ..losses import CrossEntropyLoss, TripletLoss
+from ..model_utils import weights_init_classifier
+
+
+@REID_HEADS_REGISTRY.register()
+class StandardHead(nn.Module):
+
+    def __init__(self, cfg, in_feat):
+        super().__init__()
+        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+
+        self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
+        self.classifier.apply(weights_init_classifier)
+
+    def forward(self, features, targets=None):
+        """
+        See :class:`ReIDHeads.forward`.
+        """
+        pred_class_logits = self.classifier(features)
+        return pred_class_logits
+
+    @classmethod
+    def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
+        loss_dict = {}
+        if "CrossEntropyLoss" in cfg.MODEL.LOSSES.NAME and pred_class_logits is not None:
+            loss = CrossEntropyLoss(cfg)(pred_class_logits, gt_classes)
+            loss_dict.update(loss)
+        if "TripletLoss" in cfg.MODEL.LOSSES.NAME and global_features is not None:
+            loss = TripletLoss(cfg)(global_features, gt_classes)
+            loss_dict.update(loss)
+        # rename losses with the given prefix
+        name_loss_dict = {}
+        for name in loss_dict.keys():
+            name_loss_dict[prefix + name] = loss_dict[name]
+        del loss_dict
+        return name_loss_dict
+
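`StandardHead` registers itself in `REID_HEADS_REGISTRY`, so configs can select a head by name via `MODEL.HEADS.NAME`, and the `prefix` argument of `losses` presumably lets multi-branch models (such as the `mf_network` in the conflict list) keep per-branch loss names distinct. A hedged sketch of the lookup (assuming fastreid's registry follows the detectron2 `Registry.get` convention; `build_reid_heads` here is illustrative, not necessarily the code in `.build`):

```python
from fastreid.modeling.heads import REID_HEADS_REGISTRY

def build_reid_heads(cfg, in_feat):
    # Resolve the class registered under cfg.MODEL.HEADS.NAME,
    # e.g. "StandardHead", then instantiate it with the config.
    head_cls = REID_HEADS_REGISTRY.get(cfg.MODEL.HEADS.NAME)
    return head_cls(cfg, in_feat)
```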
@@ -1,157 +0,0 @@
-# encoding: utf-8
-"""
-@author: liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-from fastreid.modeling.backbones import *
-from fastreid.modeling.model_utils import *
-from fastreid.modeling.heads import *
-from fastreid.layers import bn_no_bias, GeM
-
-
-class ClassBlock(nn.Module):
-    """
-    Define the bottleneck and classifier layer
-    |--bn--|--relu--|--linear--|--classifier--|
-    """
-    def __init__(self, in_features, num_classes, relu=True, num_bottleneck=512):
-        super().__init__()
-        block1 = []
-        block1 += [nn.BatchNorm1d(in_features)]
-        if relu:
-            block1 += [nn.LeakyReLU(0.1)]
-        block1 += [nn.Linear(in_features, num_bottleneck, bias=False)]
-        self.block1 = nn.Sequential(*block1)
-
-        self.bnneck = bn_no_bias(num_bottleneck)
-
-        # self.classifier = nn.Linear(num_bottleneck, num_classes, bias=False)
-        self.classifier = CircleLoss(num_bottleneck, num_classes, s=256, m=0.25)
-
-    def init_parameters(self):
-        self.block1.apply(weights_init_kaiming)
-        self.bnneck.apply(weights_init_kaiming)
-        self.classifier.apply(weights_init_classifier)
-
-    def forward(self, x, label=None):
-        x = self.block1(x)
-        x = self.bnneck(x)
-        if self.training:
-            cls_out = self.classifier(x, label)
-            return cls_out
-        else:
-            return x
-
-
-class MSBaseline(nn.Module):
-    def __init__(self,
-                 backbone,
-                 num_classes,
-                 last_stride,
-                 with_ibn=False,
-                 with_se=False,
-                 gcb=None,
-                 stage_with_gcb=[False, False, False, False],
-                 pretrain=True,
-                 model_path=''):
-        super().__init__()
-        if 'resnet' in backbone:
-            self.base = ResNet.from_name(backbone, pretrain, last_stride, with_ibn, with_se, gcb,
-                                         stage_with_gcb, model_path=model_path)
-            self.in_planes = 2048
-        elif 'osnet' in backbone:
-            if with_ibn:
-                self.base = osnet_ibn_x1_0(pretrained=pretrain)
-            else:
-                self.base = osnet_x1_0(pretrained=pretrain)
-            self.in_planes = 512
-        else:
-            print(f'not support {backbone} backbone')
-
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.maxpool = nn.AdaptiveMaxPool2d(1)
-        # self.gap = GeM()
-
-        self.num_classes = num_classes
-
-        self.classifier1 = ClassBlock(in_features=1024, num_classes=num_classes)
-        self.classifier2 = ClassBlock(in_features=2048, num_classes=num_classes)
-
-    def forward(self, x, label=None, **kwargs):
-        x4, x3 = self.base(x)  # x4: (bs, 2048, 16, 8)
-        x3_max = self.maxpool(x3)
-        x3_max = x3_max.view(x3_max.shape[0], -1)  # (bs, 1024)
-        x3_avg = self.avgpool(x3)
-        x3_avg = x3_avg.view(x3_avg.shape[0], -1)  # (bs, 1024)
-        x3_feat = x3_max + x3_avg
-        # x3_feat = self.gap(x3)  # (bs, 1024, 1, 1)
-        # x3_feat = x3_feat.view(x3_feat.shape[0], -1)  # (bs, 1024)
-        x4_max = self.maxpool(x4)
-        x4_max = x4_max.view(x4_max.shape[0], -1)  # (bs, 2048)
-        x4_avg = self.avgpool(x4)
-        x4_avg = x4_avg.view(x4_avg.shape[0], -1)  # (bs, 2048)
-        x4_feat = x4_max + x4_avg
-        # x4_feat = self.gap(x4)  # (bs, 2048, 1, 1)
-        # x4_feat = x4_feat.view(x4_feat.shape[0], -1)  # (bs, 2048)
-
-        if self.training:
-            cls_out3 = self.classifier1(x3_feat, label)
-            cls_out4 = self.classifier2(x4_feat, label)
-            return cls_out3, cls_out4, x3_max, x3_avg, x4_max, x4_avg
-        else:
-            x3_feat = self.classifier1(x3_feat)
-            x4_feat = self.classifier2(x4_feat)
-            return torch.cat((x3_feat, x4_feat), dim=1)
-
-    def getLoss(self, outputs, labels, **kwargs):
-        cls_out3, cls_out4, x3_max, x3_avg, x4_max, x4_avg = outputs
-
-        tri_loss = (TripletLoss(margin=0.3)(x3_max, labels, normalize_feature=False)[0]
-                    + TripletLoss(margin=0.3)(x3_avg, labels, normalize_feature=False)[0]
-                    + TripletLoss(margin=0.3)(x4_max, labels, normalize_feature=False)[0]
-                    + TripletLoss(margin=0.3)(x4_avg, labels, normalize_feature=False)[0]) / 4
-        softmax_loss = (CrossEntropyLabelSmooth(self.num_classes)(cls_out3, labels) +
-                        CrossEntropyLabelSmooth(self.num_classes)(cls_out4, labels)) / 2
-        # softmax_loss = F.cross_entropy(cls_out, labels)
-
-        self.loss = softmax_loss + tri_loss
-        # self.loss = softmax_loss
-        # return {'Softmax': softmax_loss, 'AM_Softmax': AM_softmax, 'Triplet_loss': tri_loss}
-        return {
-            'Softmax': softmax_loss,
-            'Triplet_loss': tri_loss,
-        }
-
-    def load_params_wo_fc(self, state_dict):
-        if 'classifier.weight' in state_dict:
-            state_dict.pop('classifier.weight')
-        if 'amsoftmax.weight' in state_dict:
-            state_dict.pop('amsoftmax.weight')
-        res = self.load_state_dict(state_dict, strict=False)
-        print(f'missing keys {res.missing_keys}')
-        print(f'unexpected keys {res.unexpected_keys}')
-        # assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
-
-    def unfreeze_all_layers(self):
-        self.train()
-        for p in self.parameters():
-            p.requires_grad_()
-
-    def unfreeze_specific_layer(self, names):
-        if isinstance(names, str):
-            names = [names]
-
-        for name, module in self.named_children():
-            if name in names:
-                module.train()
-                for p in module.parameters():
-                    p.requires_grad_()
-            else:
-                module.eval()
-                for p in module.parameters():
-                    p.requires_grad_(False)
@@ -12,28 +12,25 @@ For example, to launch an end-to-end baseline training on the Market1501 dataset with 4 GPUs,
 one should execute:

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 python train_net.py --config-file='configs/baseline_ibn_market1501.yml'
+CUDA_VISIBLE_DEVICES=0,1,2,3 python train_net.py --config-file='configs/AGW_market1501.yml'
 ```

 ## Experimental Results

 ### Market1501 dataset

-| Method | Pretrained | Rank@1 | mAP |
-| :---: | :---: | :---: | :---: |
-| AGW | ImageNet | 95.2% | 87.9% |
-| AGW + Ibn-a | ImageNet | 95.1% | 88.2% |
+| Method | Pretrained | Rank@1 | mAP | mINP |
+| :---: | :---: | :---: | :---: | :---: |
+| AGW | ImageNet | 94.9% | 87.4% | 63.1% |

 ### DukeMTMC dataset

-| Method | Pretrained | Rank@1 | mAP |
-| :---: | :---: | :---: | :---: |
-| AGW | ImageNet | 88.4% | 79.4% |
-| AGW + Ibn-a | ImageNet | 89.3% | 80.2% |
+| Method | Pretrained | Rank@1 | mAP | mINP |
+| :---: | :---: | :---: | :---: | :---: |
+| AGW | ImageNet | 88.9% | 79.1% | 43.2% |

 ### MSMT17 dataset

-| Method | Pretrained | Rank@1 | mAP |
-| :---: | :---: | :---: | :---: |
-| AGW | ImageNet | | |
-| AGW + Ibn-a | ImageNet | | |
+| Method | Pretrained | Rank@1 | mAP | mINP |
+| :---: | :---: | :---: | :---: | :---: |
+| AGW | ImageNet | 75.6% | 52.6% | 11.9% |
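The updated tables report mINP (mean Inverse Negative Penalty) alongside Rank@1 and mAP. As a quick illustration of the metric (my addition, following the AGW paper's definition, not code from this commit): for each query, INP = (number of correct gallery matches) / (1-based rank of the hardest, i.e. last, correct match), and mINP averages this over all queries, so it rewards retrieving every positive early rather than just the first one.

```python
# Hedged sketch of mINP, assumed from the AGW paper's definition.
import numpy as np

def inp(match_flags):
    """match_flags: 0/1 array over the ranked gallery (1 = correct match)."""
    positions = np.nonzero(match_flags)[0]
    if len(positions) == 0:
        return 0.0
    hardest_rank = positions[-1] + 1       # 1-based rank of the last correct match
    return len(positions) / hardest_rank

queries = [np.array([1, 0, 1, 0, 0]),      # 2 positives, hardest at rank 3 -> 2/3
           np.array([0, 1, 1, 1, 0])]      # 3 positives, hardest at rank 4 -> 3/4
minp = np.mean([inp(q) for q in queries])  # ~0.708
```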
@@ -4,6 +4,6 @@
 @contact: sherlockliao01@gmail.com
 """

-from .gem_pool import GeM_BN_Linear
+from .gem_pool import GemHead
 from .resnet_nl import build_resnetNL_backbone
 from .wr_triplet_loss import WeightedRegularizedTriplet
@@ -10,6 +10,8 @@ from torch import nn

 from fastreid.modeling.model_utils import weights_init_kaiming, weights_init_classifier
 from fastreid.modeling.heads import REID_HEADS_REGISTRY
+from fastreid.layers import bn_no_bias
+from fastreid.modeling.heads import StandardHead


 class GeneralizedMeanPooling(nn.Module):
@@ -53,15 +55,14 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):


 @REID_HEADS_REGISTRY.register()
-class GeM_BN_Linear(nn.Module):
+class GemHead(nn.Module):

     def __init__(self, cfg):
         super().__init__()
         self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES

         self.gem_pool = GeneralizedMeanPoolingP()
-        self.bnneck = nn.BatchNorm1d(2048)
-        self.bnneck.bias.requires_grad_(False)
+        self.bnneck = bn_no_bias(2048)
         self.bnneck.apply(weights_init_kaiming)

         self.classifier = nn.Linear(2048, self._num_classes, bias=False)
@@ -76,4 +77,8 @@
             return F.normalize(bn_features)

         pred_class_logits = self.classifier(bn_features)
-        return pred_class_logits, global_features, targets,
+        return pred_class_logits, global_features, targets
+
+    @classmethod
+    def losses(cls, cfg, pred_class_logits, global_features, gt_classes):
+        return StandardHead.losses(cfg, pred_class_logits, global_features, gt_classes)
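`GeneralizedMeanPoolingP` above is GeM pooling with a learnable exponent. A minimal self-contained sketch of the operator (my illustration, not fastreid's exact implementation): GeM(x) = (mean of x^p over spatial positions)^(1/p), where p = 1 recovers average pooling and large p approaches max pooling; the "P" variant learns p during training.

```python
import torch
import torch.nn.functional as F
from torch import nn

class GeMSketch(nn.Module):
    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)  # learnable pooling exponent
        self.eps = eps

    def forward(self, x):  # x: (bs, C, H, W)
        x = x.clamp(min=self.eps).pow(self.p)     # elementwise x^p, kept positive
        return F.avg_pool2d(x, x.shape[-2:]).pow(1.0 / self.p)  # (bs, C, 1, 1)

feat = GeMSketch()(torch.randn(16, 2048, 16, 8))  # e.g. a ResNet stage-4 feature map
```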
@@ -51,7 +51,7 @@ class ResNetNL(nn.Module):
         layers = []
         if planes == 512:
             with_ibn = False
-        layers.append(block(self.inplanes, planes, with_ibn, stride, downsample))
+        layers.append(block(self.inplanes, planes, with_ibn, stride=stride, downsample=downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(block(self.inplanes, planes, with_ibn))
@@ -142,15 +142,9 @@ def build_resnetNL_backbone(cfg):
     if not with_ibn:
         # original resnet
         state_dict = model_zoo.load_url(model_urls[depth])
-        # remove fully-connected-layers
-        state_dict.pop('fc.weight')
-        state_dict.pop('fc.bias')
     else:
         # ibn resnet
         state_dict = torch.load(pretrain_path)['state_dict']
-        # remove fully-connected-layers
-        state_dict.pop('module.fc.weight')
-        state_dict.pop('module.fc.bias')
         # remove module in name
         new_state_dict = {}
         for k in state_dict:
@@ -160,6 +154,6 @@ def build_resnetNL_backbone(cfg):
         state_dict = new_state_dict
     res = model.load_state_dict(state_dict, strict=False)
     logger = logging.getLogger('fastreid.' + __name__)
-    logger.info('missing keys is {}\n '
-                'unexpected keys is {}'.format(res.missing_keys, res.unexpected_keys))
+    logger.info('missing keys is {}'.format(res.missing_keys))
+    logger.info('unexpected keys is {}'.format(res.unexpected_keys))
     return model
@@ -6,7 +6,6 @@
 import torch
 from torch import nn
 from fastreid.modeling.losses.margin_loss import normalize, euclidean_dist
 from fastreid.modeling.losses import LOSS_REGISTRY


 def softmax_weights(dist, mask):
@@ -17,7 +16,6 @@ def softmax_weights(dist, mask):
     return W


 @LOSS_REGISTRY.register()
 class WeightedRegularizedTriplet(object):

     def __init__(self, cfg):
@@ -1,10 +1,6 @@
 _BASE_: "Base-AGW.yml"

 MODEL:
-  BACKBONE:
-    PRETRAIN: True
-    WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
   HEADS:
     NUM_CLASSES: 702

@@ -12,4 +8,4 @@ DATASETS:
   NAMES: ("DukeMTMC",)
   TESTS: ("DukeMTMC",)

-OUTPUT_DIR: "logs/fastreid_dukemtmc/agw-ibn_net"
+OUTPUT_DIR: "logs/dukemtmc/agw"
@@ -1,10 +1,6 @@
 _BASE_: "Base-AGW.yml"

 MODEL:
-  BACKBONE:
-    PRETRAIN: True
-    WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
   HEADS:
     NUM_CLASSES: 751

@@ -12,4 +8,4 @@ DATASETS:
   NAMES: ("Market1501",)
   TESTS: ("Market1501",)

-OUTPUT_DIR: "logs/fastreid_market1501/agw-ibn_net"
+OUTPUT_DIR: "logs/market1501/agw"
@@ -0,0 +1,22 @@
+_BASE_: "Base-AGW.yml"
+
+MODEL:
+  HEADS:
+    NUM_CLASSES: 1041
+
+DATASETS:
+  NAMES: ("MSMT17",)
+  TESTS: ("MSMT17",)
+
+SOLVER:
+  MAX_ITER: 45000
+  STEPS: [20000, 35000]
+
+  LOG_PERIOD: 500
+  CHECKPOINT_PERIOD: 5000
+
+TEST:
+  EVAL_PERIOD: 5000
+
+OUTPUT_DIR: "logs/msmt17/agw"
@@ -9,11 +9,11 @@ MODEL:
     PRETRAIN: True

   HEADS:
-    NAME: "GeM_BN_Linear"
+    NAME: "GemHead"
     NUM_CLASSES: 702

   LOSSES:
-    NAME: ("CrossEntropyLoss", "WeightedRegularizedTriplet")
+    NAME: ("CrossEntropyLoss", "TripletLoss")
     SMOOTH_ON: True
     SCALE_CE: 1.0
@@ -42,7 +42,7 @@ DATALOADER:
   NUM_WORKERS: 16

 SOLVER:
-  OPT: "adam"
+  OPT: "Adam"
   MAX_ITER: 18000
   BASE_LR: 0.00035
   WEIGHT_DECAY: 0.0005
@@ -52,10 +52,10 @@ SOLVER:
   STEPS: [8000, 14000]
   GAMMA: 0.1

-  WARMUP_FACTOR: 0.1
+  WARMUP_FACTOR: 0.01
   WARMUP_ITERS: 2000

-  LOG_PERIOD: 20
+  LOG_PERIOD: 200
   CHECKPOINT_PERIOD: 6000

 TEST:
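For context on the warmup change: in detectron2-style schedulers, which fastreid's solver options mirror, `WARMUP_FACTOR`/`WARMUP_ITERS` define a linear ramp from `BASE_LR * WARMUP_FACTOR` up to `BASE_LR`, after which the stepped schedule applies. A hedged sketch (my addition, assuming that convention rather than fastreid's exact scheduler code):

```python
# LR as a function of iteration under linear warmup + step decay.
BASE_LR, WARMUP_FACTOR, WARMUP_ITERS = 0.00035, 0.01, 2000
GAMMA, STEPS = 0.1, [8000, 14000]

def lr_at(it):
    if it < WARMUP_ITERS:                       # linear warmup phase
        alpha = it / WARMUP_ITERS
        return BASE_LR * (WARMUP_FACTOR * (1 - alpha) + alpha)
    decays = sum(1 for s in STEPS if it >= s)   # milestones already passed
    return BASE_LR * (GAMMA ** decays)

print(lr_at(0), lr_at(2000), lr_at(9000))       # 3.5e-06 0.00035 3.5e-05
```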
@@ -4,16 +4,25 @@
 @contact: sherlockliao01@gmail.com
 """

+import os
 import sys

 sys.path.append('../..')
 from fastreid.config import cfg
 from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
 from fastreid.utils.checkpoint import Checkpointer
+from fastreid.evaluation import ReidEvaluator

 from agwbaseline import *


+class Trainer(DefaultTrainer):
+    @classmethod
+    def build_evaluator(cls, cfg, num_query, output_folder=None):
+        if output_folder is None:
+            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
+        return ReidEvaluator(cfg, num_query)
+
+
 def setup(args):
     """
     Create configs and perform basic setups.
@@ -29,14 +38,19 @@ def main(args):
     cfg = setup(args)

     if args.eval_only:
-        model = DefaultTrainer.build_model(cfg)
+        cfg.defrost()
+        cfg.MODEL.BACKBONE.PRETRAIN = False
+        model = Trainer.build_model(cfg)
         Checkpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
             cfg.MODEL.WEIGHTS, resume=args.resume
         )
-        res = DefaultTrainer.test(cfg, model)
+        from torch import nn
+        model = nn.DataParallel(model)
+        model.cuda()
+        res = Trainer.test(cfg, model)
         return res

-    trainer = DefaultTrainer(cfg)
+    trainer = Trainer(cfg)
     trainer.resume_or_load(resume=args.resume)
     return trainer.train()