[Fix] fix issue related bugs (#161)

* [Fix] fix version check of pytorch

* [Fix] fix data format in simclr
pull/164/head
Yixiao Fang 2021-12-28 22:06:19 +08:00 committed by GitHub
parent e96d81e76c
commit f06c3c3d98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 2 deletions

View File

@ -103,7 +103,7 @@ def build_dataloader(dataset,
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None
if digit_version(torch.__version__) >= digit_version('1.7.0'):
if digit_version(torch.__version__) >= digit_version('1.8.0'):
kwargs['persistent_workers'] = persistent_workers
if kwargs.get('prefetch') is not None:

View File

@ -69,7 +69,9 @@ class SimCLR(BaseModel):
dict[str, Tensor]: A dictionary of loss components.
"""
assert isinstance(img, list)
img = torch.cat(img)
img = torch.stack(img, 1)
img = img.reshape(
(img.size(0) * 2, img.size(2), img.size(3), img.size(4)))
x = self.extract_feat(img) # 2n
z = self.neck(x)[0] # (2n)xd
z = z / (torch.norm(z, p=2, dim=1, keepdim=True) + 1e-10)