mirror of https://github.com/sthalles/SimCLR.git
fix loss function labels
parent
4c056cb919
commit
13a7e646e8
2
run.py
2
run.py
|
@ -54,7 +54,7 @@ parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
|
|||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.n_views == 2, "Only two view training is supported."
|
||||
# check if gpu training is available
|
||||
if not args.disable_cuda and torch.cuda.is_available():
|
||||
args.device = torch.device('cuda')
|
||||
|
|
|
@ -55,9 +55,6 @@ class SimCLR(object):
|
|||
# select and combine multiple positives
|
||||
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
|
||||
|
||||
# if there is more than one potive (n_views >= 2) combine the multiple positives
|
||||
positives = positives.mean(dim=1).unsqueeze(1)
|
||||
|
||||
# select only the negatives the negatives
|
||||
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue