Starting to test distributed train, fix issue with batch_size reduce

Ross Wightman 2025-04-28 16:48:06 -07:00
parent ee27b73da4
commit 39eb56f875


@@ -1142,7 +1142,10 @@ def train_one_epoch(
         if args.distributed:
             # scale gradient btw distributed ranks, each one can have different batch size
-            global_batch_size = utils.reduce_tensor(torch.tensor(batch_size, device=device), 1) # SUM
+            global_batch_size = utils.reduce_tensor(
+                torch.tensor(batch_size, device=device, dtype=torch.float32),
+                1 # SUM
+            )
             dist_scale = args.world_size * batch_size / global_batch_size
         else:
             dist_scale = None
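
For context, a minimal single-process sketch (not code from the repo) of what the dtype change addresses and what dist_scale works out to, assuming utils.reduce_tensor follows the common all_reduce-then-divide pattern (sum across ranks, then an in-place rt /= n). The world size and per-rank batch sizes below are made up for illustration.

import torch

# In-place true division on an integer tensor is rejected by PyTorch, so a
# reduce helper that sums across ranks and then divides in place (rt /= n)
# fails when handed the default int64 tensor from torch.tensor(batch_size).
int_bs = torch.tensor(96)
try:
    int_bs /= 1  # mimics the divide step inside an all_reduce-then-divide helper
except RuntimeError as err:
    print(f"int64 tensor: {err}")

# Creating the tensor as float32, as the new call does, avoids that failure.
float_bs = torch.tensor(96, dtype=torch.float32)
float_bs /= 1
print(float_bs)  # tensor(96.)

# Illustrative dist_scale values for uneven per-rank batches
# (hypothetical world_size=2 with batch sizes 96 and 64):
world_size = 2
batch_sizes = [96.0, 64.0]
global_batch_size = sum(batch_sizes)  # what the SUM reduction would yield
for rank, bs in enumerate(batch_sizes):
    dist_scale = world_size * bs / global_batch_size
    print(f"rank {rank}: dist_scale = {dist_scale:.3f}")  # 1.200, 0.800

Ranks with a larger-than-average batch get a dist_scale above 1, smaller ones below 1, so gradients averaged across ranks end up weighted by how many samples each rank actually contributed.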