diff --git a/README.md b/README.md
index b5875b5..f34591f 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ reid baseline model for exploring softmax and triplet hard loss's influence.
 | loss | rank1 | map |
 | --- | --| ---|
 | triplet hard | 89.9% | 76.8% | 
-| softmax | 87% | 65% |
-|triplet + softmax | 89.7% | 76.2% |
+| softmax | 87.9% | 70.1% |
+| triplet + softmax | 92% | 78.1% |
 
 
diff --git a/bases/base_trainer.py b/bases/base_trainer.py
index acaab41..18d22ef 100644
--- a/bases/base_trainer.py
+++ b/bases/base_trainer.py
@@ -15,12 +15,14 @@ from utils.meters import AverageMeter
 
 
 class BaseTrainer(object):
-    def __init__(self, model, criterion, tb_writer):
+    def __init__(self, opt, model, optimizer, criterion, summary_writer):
+        self.opt = opt
         self.model = model
+        self.optimizer = optimizer
         self.criterion = criterion
-        self.tb_writer = tb_writer
+        self.summary_writer = summary_writer
 
-    def train(self, epoch, data_loader, optimizer, print_freq=1):
+    def train(self, epoch, data_loader):
         self.model.train()
 
         batch_time = AverageMeter()
@@ -32,23 +34,23 @@ class BaseTrainer(object):
             data_time.update(time.time() - start)
 
             # model optimizer
-            inputs, targets = self._parse_data(inputs)
-            loss = self._forward(inputs, targets)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+            self._parse_data(inputs)
+            self._forward()
+            self.optimizer.zero_grad()
+            self._backward()
+            self.optimizer.step()
 
             batch_time.update(time.time() - start)
-            losses.update(loss.item())
+            losses.update(self.loss.item())
 
             # tensorboard
             global_step = epoch * len(data_loader) + i
-            self.tb_writer.add_scalar('loss', loss.item(), global_step)
-            self.tb_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step)
+            self.summary_writer.add_scalar('loss', self.loss.item(), global_step)
+            self.summary_writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], global_step)
 
             start = time.time()
 
-            if (i + 1) % print_freq == 0:
+            if (i + 1) % self.opt.print_freq == 0:
                 print('Epoch: [{}][{}/{}]\t'
                       'Batch Time {:.3f} ({:.3f})\t'
                       'Data Time {:.3f} ({:.3f})\t'
@@ -57,7 +59,7 @@ class BaseTrainer(object):
                               batch_time.val, batch_time.mean,
                               data_time.val, data_time.mean,
                               losses.val, losses.mean))
-        param_group = optimizer.param_groups
+        param_group = self.optimizer.param_groups
         print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3e}\t'
               'Lr {:.2e}'
               .format(epoch, batch_time.sum, losses.mean, param_group[0]['lr']))
@@ -66,5 +68,8 @@ class BaseTrainer(object):
     def _parse_data(self, inputs):
         raise NotImplementedError
 
-    def _forward(self, inputs, targets):
+    def _forward(self):
+        raise NotImplementedError
+
+    def _backward(self):
         raise NotImplementedError
diff --git a/config.py b/config.py
index ebc3999..9f0d0f7 100644
--- a/config.py
+++ b/config.py
@@ -17,8 +17,8 @@ class DefaultConfig(object):
 
     # dataset options
     dataset = 'market'
-    height = 384
-    width = 192
+    height = 256
+    width = 128
 
     # optimization options
     optim = 'Adam'
@@ -32,16 +32,16 @@ class DefaultConfig(object):
     momentum = 0.9
     margin = 0.3
     num_instances = 4
+    num_gpu = 1
 
     # model options
-    model_name = 'ResNetBuilder'
+    model_name = 'softmax'  # softmax, triplet, softmax_triplet
     last_stride = 1
 
     # miscs
     print_freq = 30
     eval_step = 50
     save_dir = '/DATA/pytorch-ckpt/market'
-    gpu = 0, 1
     workers = 10
     start_epoch = 0
 
@@ -56,4 +56,4 @@ class DefaultConfig(object):
                 if not k.startswith('_')}
 
 
-args = DefaultConfig()
+opt = DefaultConfig()
diff --git a/main_reid.py b/main_reid.py
index 4a669e5..6de54d8 100644
--- a/main_reid.py
+++ b/main_reid.py
@@ -16,328 +16,147 @@ from pprint import pprint
 
 import numpy as np
 import torch
-from datasets.samplers import RandomIdentitySampler
 from tensorboardX import SummaryWriter
 from torch import nn
 from torch.backends import cudnn
 from torch.utils.data import DataLoader
 
-from config import args
+from config import opt
 from datasets import data_manager
 from datasets.data_loader import ImageData
-from models import ResNetBuilder
-from trainers import ResNetClsTrainer, ResNetTriTrainer, ResNetClsTriTrainer, ResNetEvaluator
+from datasets.samplers import RandomIdentitySampler
+from models import get_baseline_model
+from trainers import clsTrainer, cls_tripletTrainer, tripletTrainer, ResNetEvaluator
 from utils.loss import TripletLoss
 from utils.serialization import Logger
 from utils.serialization import save_checkpoint
 from utils.transforms import TrainTransform, TestTransform
 
 
-def train_classification(**kwargs):
-    args._parse(kwargs)
+def train(**kwargs):
+    opt._parse(kwargs)
 
     # set random seed and cudnn benchmark
-    torch.manual_seed(args.seed)
+    torch.manual_seed(opt.seed)
 
     use_gpu = torch.cuda.is_available()
-    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
+    sys.stdout = Logger(osp.join(opt.save_dir, 'log_train.txt'))
 
     print('=========user config==========')
-    pprint(args._state_dict())
+    pprint(opt._state_dict())
     print('============end===============')
 
     if use_gpu:
-        print('currently using GPU {}'.format(args.gpu))
+        print('currently using GPU')
         cudnn.benchmark = True
-        torch.cuda.manual_seed_all(args.seed)
-        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+        torch.cuda.manual_seed_all(opt.seed)
     else:
         print('currently using cpu')
 
-    print('initializing dataset {}'.format(args.dataset))
-    dataset = data_manager.init_dataset(name=args.dataset)
+    print('initializing dataset {}'.format(opt.dataset))
+    dataset = data_manager.init_dataset(name=opt.dataset)
 
     pin_memory = True if use_gpu else False
 
-    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))
+    summary_writer = SummaryWriter(osp.join(opt.save_dir, 'tensorboard_log'))
 
-    trainloader = DataLoader(
-        ImageData(dataset.train, TrainTransform(args.height, args.width)),
-        batch_size=args.train_batch, shuffle=True, num_workers=args.workers,
-        pin_memory=pin_memory, drop_last=True
-    )
+    if 'triplet' in opt.model_name:
+        trainloader = DataLoader(
+            ImageData(dataset.train, TrainTransform(opt.height, opt.width)),
+            sampler=RandomIdentitySampler(dataset.train, opt.num_instances),
+            batch_size=opt.train_batch, num_workers=opt.workers,
+            pin_memory=pin_memory, drop_last=True
+        )
+    else:
+        trainloader = DataLoader(
+            ImageData(dataset.train, TrainTransform(opt.height, opt.width)),
+            batch_size=opt.train_batch, shuffle=True, num_workers=opt.workers,
+            pin_memory=pin_memory
+        )
 
     queryloader = DataLoader(
-        ImageData(dataset.query, TestTransform(args.height, args.width)),
-        batch_size=args.test_batch, num_workers=args.workers,
+        ImageData(dataset.query, TestTransform(opt.height, opt.width)),
+        batch_size=opt.test_batch, num_workers=opt.workers,
         pin_memory=pin_memory
     )
 
     galleryloader = DataLoader(
-        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
-        batch_size=args.test_batch, num_workers=args.workers,
+        ImageData(dataset.gallery, TestTransform(opt.height, opt.width)),
+        batch_size=opt.test_batch, num_workers=opt.workers,
         pin_memory=pin_memory
     )
 
     print('initializing model ...')
-    model = ResNetBuilder(num_classes=dataset.num_train_pids)
+    if opt.model_name == 'softmax' or opt.model_name == 'softmax_triplet':
+        model, optim_policy = get_baseline_model(dataset.num_train_pids)
+    elif opt.model_name == 'triplet':
+        model, optim_policy = get_baseline_model(num_classes=None)
     print('model size: {:.5f}M'.format(sum(p.numel()
                                            for p in model.parameters()) / 1e6))
 
-    cls_criterion = nn.CrossEntropyLoss()
+    xent_criterion = nn.CrossEntropyLoss()
+    tri_criterion = TripletLoss(opt.margin)
 
-    def xent_criterion(cls_scores, targets):
-        cls_loss = cls_criterion(cls_scores, targets)
+    def cls_criterion(cls_scores, targets):
+        cls_loss = xent_criterion(cls_scores, targets)
         return cls_loss
 
-    # get optimizer
-    optimizer = torch.optim.SGD(
-        model.optim_policy(), lr=args.lr, weight_decay=args.weight_decay,
-        momentum=args.momentum, nesterov=True
-    )
+    def triplet_criterion(feat, targets):
+        triplet_loss, _, _ = tri_criterion(feat, targets)
+        return triplet_loss
 
-    def adjust_lr(optimizer, ep, decay_ep, gamma):
-        decay = gamma ** float(ep // decay_ep)
-        for g in optimizer.param_groups:
-            g['lr'] = args.lr * decay * g.get('lr_multi', 1)
-
-    start_epoch = args.start_epoch
-    if use_gpu:
-        model = nn.DataParallel(model).cuda()
-
-    # get trainer and evaluator
-    reid_trainer = ResNetClsTrainer(model, xent_criterion, tb_writer)
-    reid_evaluator = ResNetEvaluator(model)
-
-    # start training
-    best_rank1 = -np.inf
-    best_epoch = 0
-    for epoch in range(start_epoch, args.max_epoch):
-        if args.step_size > 0:
-            adjust_lr(optimizer, epoch + 1, args.step_size, args.gamma)
-        reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)
-
-        # skip if not save model
-        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
-            rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
-            is_best = rank1 > best_rank1
-            if is_best:
-                best_rank1 = rank1
-                best_epoch = epoch + 1
-
-            if use_gpu:
-                state_dict = model.module.state_dict()
-            else:
-                state_dict = model.state_dict()
-            save_checkpoint({
-                'state_dict': state_dict,
-                'epoch': epoch + 1,
-            }, is_best=is_best, save_dir=args.save_dir, filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
-
-    print(
-        'Best rank-1 {:.1%}, achived at epoch {}'.format(best_rank1, best_epoch))
-
-
-def train_triplet(**kwargs):
-    args._parse(kwargs)
-
-    # set random seed and cudnn benchmark
-    torch.manual_seed(args.seed)
-
-    use_gpu = torch.cuda.is_available()
-    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
-
-    print('=========user config==========')
-    pprint(args._state_dict())
-    print('============end===============')
-
-    if use_gpu:
-        print('currently using GPU {}'.format(args.gpu))
-        cudnn.benchmark = True
-        torch.cuda.manual_seed_all(args.seed)
-        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
-    else:
-        print('currently using cpu')
-
-    print('initializing dataset {}'.format(args.dataset))
-    dataset = data_manager.init_dataset(name=args.dataset)
-
-    pin_memory = True if use_gpu else False
-
-    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))
-
-    trainloader = DataLoader(
-        ImageData(dataset.train, TrainTransform(args.height, args.width)),
-        sampler=RandomIdentitySampler(dataset.train, args.num_instances),
-        batch_size=args.train_batch, num_workers=args.workers,
-        pin_memory=pin_memory, drop_last=True
-    )
-
-    queryloader = DataLoader(
-        ImageData(dataset.query, TestTransform(args.height, args.width)),
-        batch_size=args.test_batch, num_workers=args.workers,
-        pin_memory=pin_memory
-    )
-
-    galleryloader = DataLoader(
-        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
-        batch_size=args.test_batch, num_workers=args.workers,
-        pin_memory=pin_memory
-    )
-
-    print('initializing model ...')
-    model = ResNetBuilder()
-    print('model size: {:.5f}M'.format(sum(p.numel()
-                                           for p in model.parameters()) / 1e6))
-
-    tri_criterion = TripletLoss(margin=args.margin)
-
-    def tri_hard(feat, targets):
-        tri_loss, _, _ = tri_criterion(feat, targets)
-        return tri_loss
-
-    # get optimizer
-    optimizer = torch.optim.Adam(
-        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
-    )
-
-    def adjust_lr_exp(optimizer, base_lr, ep, total_ep, start_decay_ep, gamma):
-        if ep < start_decay_ep:
-            return
-        lr_decay = gamma ** (float(ep - start_decay_ep) /
-                             (total_ep - start_decay_ep))
-        for g in optimizer.param_groups:
-            g['lr'] = base_lr * lr_decay
-
-    start_epoch = args.start_epoch
-    if use_gpu:
-        model = nn.DataParallel(model).cuda()
-
-    # get trainer and evaluator
-    reid_trainer = ResNetTriTrainer(model, tri_hard, tb_writer)
-    reid_evaluator = ResNetEvaluator(model)
-
-    # start training
-    best_rank1 = -np.inf
-    best_epoch = 0
-    for epoch in range(start_epoch, args.max_epoch):
-        if args.step_size > 0:
-            adjust_lr_exp(optimizer, args.lr, epoch + 1, args.max_epoch, args.step_size, args.gamma)
-        reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)
-
-        # skip if not save model
-        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
-            rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
-            is_best = rank1 > best_rank1
-            if is_best:
-                best_rank1 = rank1
-                best_epoch = epoch + 1
-
-            if use_gpu:
-                state_dict = model.module.state_dict()
-            else:
-                state_dict = model.state_dict()
-            save_checkpoint({
-                'state_dict': state_dict,
-                'epoch': epoch + 1,
-            }, is_best=is_best, save_dir=args.save_dir, filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
-
-    print(
-        'Best rank-1 {:.1%}, achived at epoch {}'.format(best_rank1, best_epoch))
-
-
-def train_cls_triplet(**kwargs):
-    args._parse(kwargs)
-
-    # set random seed and cudnn benchmark
-    torch.manual_seed(args.seed)
-
-    use_gpu = torch.cuda.is_available()
-    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
-
-    print('=========user config==========')
-    pprint(args._state_dict())
-    print('============end===============')
-
-    if use_gpu:
-        print('currently using GPU {}'.format(args.gpu))
-        cudnn.benchmark = True
-        torch.cuda.manual_seed_all(args.seed)
-        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
-    else:
-        print('currently using cpu')
-
-    print('initializing dataset {}'.format(args.dataset))
-    dataset = data_manager.init_dataset(name=args.dataset)
-
-    pin_memory = True if use_gpu else False
-
-    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))
-
-    trainloader = DataLoader(
-        ImageData(dataset.train, TrainTransform(args.height, args.width)),
-        sampler=RandomIdentitySampler(dataset.train, args.num_instances),
-        batch_size=args.train_batch, num_workers=args.workers,
-        pin_memory=pin_memory, drop_last=True
-    )
-
-    queryloader = DataLoader(
-        ImageData(dataset.query, TestTransform(args.height, args.width)),
-        batch_size=args.test_batch, num_workers=args.workers,
-        pin_memory=pin_memory
-    )
-
-    galleryloader = DataLoader(
-        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
-        batch_size=args.test_batch, num_workers=args.workers,
-        pin_memory=pin_memory
-    )
-
-    print('initializing model ...')
-    model = ResNetBuilder(num_classes=dataset.num_train_pids)
-    print('model size: {:.5f}M'.format(sum(p.numel()
-                                           for p in model.parameters()) / 1e6))
-
-    cls_criterion = nn.CrossEntropyLoss()
-    tri_criterion = TripletLoss(margin=args.margin)
-
-    def xent_tri_criterion(cls_scores, global_feat, targets):
-        cls_loss = cls_criterion(cls_scores, targets)
-        tri_loss, dist_ap, dist_an = tri_criterion(global_feat, targets)
-        loss = cls_loss + tri_loss
+    def cls_tri_criterion(cls_scores, feat, targets):
+        cls_loss = xent_criterion(cls_scores, targets)
+        triplet_loss, _, _ = tri_criterion(feat, targets)
+        loss = cls_loss + triplet_loss
         return loss
 
     # get optimizer
     optimizer = torch.optim.Adam(
-        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
+        optim_policy, lr=opt.lr, weight_decay=opt.weight_decay,
     )
 
-    def adjust_lr_exp(optimizer, base_lr, ep, total_ep, start_decay_ep, gamma):
-        if ep < start_decay_ep:
-            return
-        lr_decay = gamma ** (float(ep - start_decay_ep) /
-                             (total_ep - start_decay_ep))
-        for g in optimizer.param_groups:
-            g['lr'] = base_lr * lr_decay
+    def adjust_lr(optimizer, ep):
+        if ep < 20:
+            lr = 1e-4 * (ep + 1) / 2
+        elif ep < 80:
+            lr = 1e-3 * opt.num_gpu
+        elif 80 <= ep <= 180:
+            lr = 1e-4 * opt.num_gpu
+        elif 180 <= ep <= 300:
+            lr = 1e-5 * opt.num_gpu
+        elif 300 <= ep <= 320:
+            lr = 1e-4 * (ep - 300 + 1) / 2 * opt.num_gpu
+        elif 380 <= ep <= 480:
+            lr = 1e-4 * opt.num_gpu
+        else:
+            lr = 1e-5 * opt.num_gpu
+        for p in optimizer.param_groups:
+            p['lr'] = lr
 
-    start_epoch = args.start_epoch
+    start_epoch = opt.start_epoch
     if use_gpu:
         model = nn.DataParallel(model).cuda()
 
     # get trainer and evaluator
-    reid_trainer = ResNetClsTriTrainer(model, xent_tri_criterion, tb_writer)
+    if opt.model_name == 'softmax':
+        reid_trainer = clsTrainer(opt, model, optimizer, cls_criterion, summary_writer)
+    elif opt.model_name == 'softmax_triplet':
+        reid_trainer = cls_tripletTrainer(opt, model, optimizer, cls_tri_criterion, summary_writer)
+    elif opt.model_name == 'triplet':
+        reid_trainer = tripletTrainer(opt, model, optimizer, triplet_criterion, summary_writer)
     reid_evaluator = ResNetEvaluator(model)
 
     # start training
     best_rank1 = -np.inf
     best_epoch = 0
-    for epoch in range(start_epoch, args.max_epoch):
-        if args.step_size > 0:
-            adjust_lr_exp(optimizer, args.lr, epoch + 1, args.max_epoch, args.step_size, args.gamma)
-        reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)
+    for epoch in range(start_epoch, opt.max_epoch):
+        if opt.step_size > 0:
+            adjust_lr(optimizer, epoch + 1)
+        reid_trainer.train(epoch, trainloader)
 
         # skip if not save model
-        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
+        if opt.eval_step > 0 and (epoch + 1) % opt.eval_step == 0 or (epoch + 1) == opt.max_epoch:
             rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
             is_best = rank1 > best_rank1
             if is_best:
@@ -351,49 +170,49 @@ def train_cls_triplet(**kwargs):
             save_checkpoint({
                 'state_dict': state_dict,
                 'epoch': epoch + 1,
-            }, is_best=is_best, save_dir=args.save_dir, filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
+            }, is_best=is_best, save_dir=opt.save_dir, filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
 
     print(
         'Best rank-1 {:.1%}, achived at epoch {}'.format(best_rank1, best_epoch))
 
 
 def test(**kwargs):
-    args._parse(kwargs)
+    opt._parse(kwargs)
 
     # set random seed and cudnn benchmark
-    torch.manual_seed(args.seed)
+    torch.manual_seed(opt.seed)
 
     use_gpu = torch.cuda.is_available()
-    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
+    sys.stdout = Logger(osp.join(opt.save_dir, 'log_train.txt'))
 
     if use_gpu:
-        print('currently using GPU {}'.format(args.gpu))
+        print('currently using GPU')
         cudnn.benchmark = True
-        torch.cuda.manual_seed_all(args.seed)
-        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
+        torch.cuda.manual_seed_all(opt.seed)
+        # GPU selection is done via the CUDA_VISIBLE_DEVICES env var (see scripts/)
     else:
         print('currently using cpu')
 
-    print('initializing dataset {}'.format(args.dataset))
-    dataset = data_manager.init_dataset(name=args.dataset)
+    print('initializing dataset {}'.format(opt.dataset))
+    dataset = data_manager.init_dataset(name=opt.dataset)
 
     pin_memory = True if use_gpu else False
 
     queryloader = DataLoader(
-        ImageData(dataset.query, TestTransform(args.height, args.width)),
-        batch_size=args.test_batch, num_workers=args.workers,
+        ImageData(dataset.query, TestTransform(opt.height, opt.width)),
+        batch_size=opt.test_batch, num_workers=opt.workers,
         pin_memory=pin_memory
     )
 
     galleryloader = DataLoader(
-        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
-        batch_size=args.test_batch, num_workers=args.workers,
+        ImageData(dataset.gallery, TestTransform(opt.height, opt.width)),
+        batch_size=opt.test_batch, num_workers=opt.workers,
         pin_memory=pin_memory
     )
 
     print('loading model ...')
-    model = ResNetBuilder(num_classes=dataset.num_train_pids)
-    # ckpt = torch.load(args.load_model)
+    model, optim_policy = get_baseline_model(dataset.num_train_pids)
+    # ckpt = torch.load(opt.load_model)
     # model.load_state_dict(ckpt['state_dict'])
     print('model size: {:.5f}M'.format(sum(p.numel()
                                            for p in model.parameters()) / 1e6))
diff --git a/models/__init__.py b/models/__init__.py
index 68e22d9..e6d762c 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -9,4 +9,5 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
-from .resnet_reid import ResNetBuilder
+from .baseline_model import ResNetBuilder
+from .networks import get_baseline_model
\ No newline at end of file
diff --git a/models/resnet_reid.py b/models/baseline_model.py
similarity index 70%
rename from models/resnet_reid.py
rename to models/baseline_model.py
index f9774c4..f1b2f7c 100644
--- a/models/resnet_reid.py
+++ b/models/baseline_model.py
@@ -9,6 +9,8 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
+import itertools
+
 import torch.nn.functional as F
 from torch import nn
 
@@ -22,7 +24,7 @@ def weights_init_kaiming(m):
         nn.init.constant_(m.bias, 0.0)
     elif classname.find('Conv') != -1:
         nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
-        if hasattr(m, 'bias'):
+        if m.bias is not None:
             nn.init.constant_(m.bias, 0.0)
     elif classname.find('BatchNorm') != -1:
         if m.affine:
@@ -58,24 +60,35 @@ class ResNetBuilder(nn.Module):
             self.classifier.apply(weights_init_classifier)
 
     def forward(self, x):
-        feat = self.base(x)
-        global_feat = F.avg_pool2d(feat, feat.shape[2:])  # (b, 2048, 1, 1)
+        global_feat = self.base(x)
+        global_feat = F.avg_pool2d(global_feat, global_feat.shape[2:])  # (b, 2048, 1, 1)
         global_feat = global_feat.view(global_feat.shape[0], -1)
         if self.training and self.num_classes is not None:
-            global_feat = self.bottleneck(global_feat)
-            cls_score = self.classifier(global_feat)
+            feat = self.bottleneck(global_feat)
+            cls_score = self.classifier(feat)
             return cls_score, global_feat
         else:
             return global_feat
 
-    def optim_policy(self):
+    def get_optim_policy(self):
         base_param_group = self.base.parameters()
-        other_param_group = list()
-        other_param_group.extend(list(self.bottleneck.parameters()))
-        other_param_group.extend(list(self.classifier.parameters()))
-        return [
-            {'params': base_param_group, 'lr_multi': 0.1},
-            {'params': other_param_group}
-        ]
+        if self.num_classes is not None:
+            add_param_group = itertools.chain(self.bottleneck.parameters(), self.classifier.parameters())
+            return [
+                {'params': base_param_group},
+                {'params': add_param_group}
+            ]
+        else:
+            return [
+                {'params': base_param_group}
+            ]
 
 
+if __name__ == '__main__':
+    net = ResNetBuilder(None)
+    net.cuda()
+    import torch as th
+    x = th.ones(2, 3, 256, 128).cuda()
+    y = net(x)
+    from IPython import embed
+    embed()
\ No newline at end of file
diff --git a/models/networks.py b/models/networks.py
new file mode 100644
index 0000000..892be74
--- /dev/null
+++ b/models/networks.py
@@ -0,0 +1,18 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from .baseline_model import ResNetBuilder
+
+
+def get_baseline_model(num_classes, last_stride=1, model_path='/DATA/model_zoo/resnet50-19c8e357.pth'):
+    model = ResNetBuilder(num_classes, last_stride, model_path)
+    optim_policy = model.get_optim_policy()
+    return model, optim_policy
diff --git a/models/resnet.py b/models/resnet.py
index f63fa5f..fa6ca5d 100644
--- a/models/resnet.py
+++ b/models/resnet.py
@@ -11,7 +11,7 @@ from __future__ import unicode_literals
 
 import math
 
-import torch
+import torch as th
 from torch import nn
 
 
@@ -98,7 +98,7 @@ class ResNet(nn.Module):
         return x
 
     def load_param(self, model_path):
-        param_dict = torch.load(model_path)
+        param_dict = th.load(model_path)
         for i in param_dict:
             if 'fc' in i:
                 continue
@@ -117,5 +117,6 @@ class ResNet(nn.Module):
 if __name__ == "__main__":
     net = ResNet(last_stride=2)
     import torch
+
     x = net(torch.zeros(1, 3, 256, 128))
     print(x.shape)
diff --git a/scripts/train_classification.sh b/scripts/train_classification.sh
deleted file mode 100644
index ec7467d..0000000
--- a/scripts/train_classification.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#!/usr/bin/env bash
-
-python3 ../main_sk_image_model.py train --save_dir='/DATA/pytorch-ckpt/market1501'
diff --git a/scripts/train_softmax.sh b/scripts/train_softmax.sh
new file mode 100644
index 0000000..eb7edbf
--- /dev/null
+++ b/scripts/train_softmax.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+
+CUDA_VISIBLE_DEVICES=0 python3 main_reid.py train --save_dir='/DATA/pytorch-ckpt/market1501_softmax' --max_epoch=400 \
+--eval_step=50 --model_name='softmax'
diff --git a/scripts/train_triplet.sh b/scripts/train_triplet.sh
new file mode 100644
index 0000000..26b5480
--- /dev/null
+++ b/scripts/train_triplet.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+
+CUDA_VISIBLE_DEVICES=0 python3 main_reid.py train --save_dir='/DATA/pytorch-ckpt/market1501_triplet' --max_epoch=400 \
+--eval_step=50 --model_name='triplet'
\ No newline at end of file
diff --git a/scripts/train_triplet_softmax.sh b/scripts/train_triplet_softmax.sh
new file mode 100644
index 0000000..84cfefd
--- /dev/null
+++ b/scripts/train_triplet_softmax.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+
+CUDA_VISIBLE_DEVICES=0 python3 main_reid.py train --save_dir='/DATA/pytorch-ckpt/market1501_softmax_triplet' \
+--max_epoch=400 --eval_step=50 --model_name='softmax_triplet'
diff --git a/trainers/__init__.py b/trainers/__init__.py
index 4af45ee..89328b4 100644
--- a/trainers/__init__.py
+++ b/trainers/__init__.py
@@ -10,4 +10,4 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from .evaluator import ResNetEvaluator
-from .trainer import ResNetClsTrainer, ResNetTriTrainer, ResNetClsTriTrainer
+from .trainer import cls_tripletTrainer, tripletTrainer, clsTrainer
diff --git a/trainers/trainer.py b/trainers/trainer.py
index d1b68e0..2ed8149 100644
--- a/trainers/trainer.py
+++ b/trainers/trainer.py
@@ -12,43 +12,52 @@ from __future__ import unicode_literals
 from bases.base_trainer import BaseTrainer
 
 
-class ResNetClsTrainer(BaseTrainer):
-    def __init__(self, model, criterion, tb_writer):
-        super().__init__(model, criterion, tb_writer)
+class clsTrainer(BaseTrainer):
+    def __init__(self, opt, model, optimizer, criterion, summary_writer):
+        super().__init__(opt, model, optimizer, criterion, summary_writer)
 
     def _parse_data(self, inputs):
         imgs, pids, _ = inputs
-        return imgs.cuda(), pids.cuda()
+        self.data = imgs.cuda()
+        self.target = pids.cuda()
 
-    def _forward(self, inputs, targets):
-        cls_score, _ = self.model(inputs)
-        loss = self.criterion(cls_score, targets)
-        return loss
+    def _forward(self):
+        score, _ = self.model(self.data)
+        self.loss = self.criterion(score, self.target)
+
+    def _backward(self):
+        self.loss.backward()
 
 
-class ResNetTriTrainer(BaseTrainer):
-    def __init__(self, model, criterion, tb_writer):
-        super().__init__(model, criterion, tb_writer)
+class tripletTrainer(BaseTrainer):
+    def __init__(self, opt, model, optimizer, criterion, summary_writer):
+        super().__init__(opt, model, optimizer, criterion, summary_writer)
 
     def _parse_data(self, inputs):
         imgs, pids, _ = inputs
-        return imgs.cuda(), pids.cuda()
+        self.data = imgs.cuda()
+        self.target = pids.cuda()
 
-    def _forward(self, inputs, targets):
-        feat = self.model(inputs)
-        loss = self.criterion(feat, targets)
-        return loss
+    def _forward(self):
+        feat = self.model(self.data)
+        self.loss = self.criterion(feat, self.target)
+
+    def _backward(self):
+        self.loss.backward()
 
 
-class ResNetClsTriTrainer(BaseTrainer):
-    def __init__(self, model, criterion, tb_writer):
-        super().__init__(model, criterion, tb_writer)
+class cls_tripletTrainer(BaseTrainer):
+    def __init__(self, opt, model, optimizer, criterion, summary_writer):
+        super().__init__(opt, model, optimizer, criterion, summary_writer)
 
     def _parse_data(self, inputs):
         imgs, pids, _ = inputs
-        return imgs.cuda(), pids.cuda()
+        self.data = imgs.cuda()
+        self.target = pids.cuda()
 
-    def _forward(self, inputs, targets):
-        cls_score, feat = self.model(inputs)
-        loss = self.criterion(cls_score, feat, targets)
-        return loss
\ No newline at end of file
+    def _forward(self):
+        score, feat = self.model(self.data)
+        self.loss = self.criterion(score, feat, self.target)
+
+    def _backward(self):
+        self.loss.backward()