Type fixes, remove old comments
parent
0893f5d296
commit
825edccf19
|
@ -36,10 +36,7 @@ class NaFlexCollator:
|
|||
assert isinstance(batch[0], tuple)
|
||||
batch_size = len(batch)
|
||||
|
||||
# FIXME
|
||||
# get seq len from sampler schedule
|
||||
|
||||
# resize to final size based on seq_len and patchify
|
||||
# Resize to final size based on seq_len and patchify
|
||||
|
||||
# Extract targets
|
||||
targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
|
||||
|
|
|
@ -575,7 +575,7 @@ class RandomResizedCropToSequence(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def get_params(
|
||||
img: Union[torch.Tensor, Image],
|
||||
img: torch.Tensor,
|
||||
scale: Tuple[float, float],
|
||||
ratio: Tuple[float, float],
|
||||
crop_attempts: int = 10,
|
||||
|
@ -690,7 +690,7 @@ class RandomResizedCropToSequence(torch.nn.Module):
|
|||
|
||||
return (top, left, crop_h, crop_w), final_size, interpolation
|
||||
|
||||
def forward(self, img: Union[torch.Tensor, Image]) -> torch.Tensor:
|
||||
def forward(self, img: torch.Tensor) -> torch.Tensor:
|
||||
# Sample crop, resize, and interpolation parameters
|
||||
crop_params, final_size, interpolation = self.get_params(
|
||||
img,
|
||||
|
|
Loading…
Reference in New Issue