mirror of https://github.com/sthalles/SimCLR.git
fix loss function labels
parent
13a7e646e8
commit
55910e6107
2
run.py
2
run.py
|
@ -54,7 +54,7 @@ parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
assert args.n_views == 2, "Only two view training is supported."
|
assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
|
||||||
# check if gpu training is available
|
# check if gpu training is available
|
||||||
if not args.disable_cuda and torch.cuda.is_available():
|
if not args.disable_cuda and torch.cuda.is_available():
|
||||||
args.device = torch.device('cuda')
|
args.device = torch.device('cuda')
|
||||||
|
|
Loading…
Reference in New Issue