diff --git a/train_img_model_xent.py b/train_img_model_xent.py index a039a01..bc2d038 100644 --- a/train_img_model_xent.py +++ b/train_img_model_xent.py @@ -8,6 +8,7 @@ import os.path as osp import numpy as np import torch +import torch.nn as nn import torch.backends.cudnn as cudnn from torch.utils.data import DataLoader from torch.autograd import Variable