diff --git a/train_img_model_cent.py b/train_img_model_cent.py index 5cdc282..66fdb56 100644 --- a/train_img_model_cent.py +++ b/train_img_model_cent.py @@ -166,7 +166,7 @@ def main(): if is_best: best_rank1 = rank1 if use_gpu: - state_dict = model.module.cpu().state_dict() + state_dict = model.module.state_dict() else: state_dict = model.state_dict() save_checkpoint({ diff --git a/train_img_model_xent.py b/train_img_model_xent.py index facdeff..8cc3714 100755 --- a/train_img_model_xent.py +++ b/train_img_model_xent.py @@ -159,7 +159,7 @@ def main(): if is_best: best_rank1 = rank1 if use_gpu: - state_dict = model.module.cpu().state_dict() + state_dict = model.module.state_dict() else: state_dict = model.state_dict() save_checkpoint({ diff --git a/train_img_model_xent_htri.py b/train_img_model_xent_htri.py index a488f7e..a2b5a7c 100755 --- a/train_img_model_xent_htri.py +++ b/train_img_model_xent_htri.py @@ -168,7 +168,7 @@ def main(): if is_best: best_rank1 = rank1 if use_gpu: - state_dict = model.module.cpu().state_dict() + state_dict = model.module.state_dict() else: state_dict = model.state_dict() save_checkpoint({ diff --git a/train_vid_model_xent.py b/train_vid_model_xent.py index eaebe14..91efd95 100755 --- a/train_vid_model_xent.py +++ b/train_vid_model_xent.py @@ -167,7 +167,7 @@ def main(): if is_best: best_rank1 = rank1 if use_gpu: - state_dict = model.module.cpu().state_dict() + state_dict = model.module.state_dict() else: state_dict = model.state_dict() save_checkpoint({ diff --git a/train_vid_model_xent_htri.py b/train_vid_model_xent_htri.py index d5fb229..beae0db 100755 --- a/train_vid_model_xent_htri.py +++ b/train_vid_model_xent_htri.py @@ -176,7 +176,7 @@ def main(): if is_best: best_rank1 = rank1 if use_gpu: - state_dict = model.module.cpu().state_dict() + state_dict = model.module.state_dict() else: state_dict = model.state_dict() save_checkpoint({