[Fix] fix issue related bugs (#161)
* [Fix] fix version check of pytorch * [Fix] fix data format in simclrpull/164/head
parent
e96d81e76c
commit
f06c3c3d98
|
@ -103,7 +103,7 @@ def build_dataloader(dataset,
|
|||
worker_init_fn, num_workers=num_workers, rank=rank,
|
||||
seed=seed) if seed is not None else None
|
||||
|
||||
if digit_version(torch.__version__) >= digit_version('1.7.0'):
|
||||
if digit_version(torch.__version__) >= digit_version('1.8.0'):
|
||||
kwargs['persistent_workers'] = persistent_workers
|
||||
|
||||
if kwargs.get('prefetch') is not None:
|
||||
|
|
|
@ -69,7 +69,9 @@ class SimCLR(BaseModel):
|
|||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
assert isinstance(img, list)
|
||||
img = torch.cat(img)
|
||||
img = torch.stack(img, 1)
|
||||
img = img.reshape(
|
||||
(img.size(0) * 2, img.size(2), img.size(3), img.size(4)))
|
||||
x = self.extract_feat(img) # 2n
|
||||
z = self.neck(x)[0] # (2n)xd
|
||||
z = z / (torch.norm(z, p=2, dim=1, keepdim=True) + 1e-10)
|
||||
|
|
Loading…
Reference in New Issue