From 39eb56f875cee75bc8962147ece93053db64918f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 28 Apr 2025 16:48:06 -0700
Subject: [PATCH] Starting to test distributed train, fix issue with batch_size reduce

---
 train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 5f7b7791..efbe2892 100755
--- a/train.py
+++ b/train.py
@@ -1142,7 +1142,10 @@ def train_one_epoch(
         if args.distributed:
             # scale gradient btw distributed ranks, each one can have different batch size
-            global_batch_size = utils.reduce_tensor(torch.tensor(batch_size, device=device), 1)  # SUM
+            global_batch_size = utils.reduce_tensor(
+                torch.tensor(batch_size, device=device, dtype=torch.float32),
+                1  # SUM
+            )
             dist_scale = args.world_size * batch_size / global_batch_size
         else:
             dist_scale = None
 
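
For reference, a minimal sketch of what the changed block computes, assuming utils.reduce_tensor is the usual all_reduce(SUM)-then-divide-by-n helper (as in timm's utils) and that torch.distributed is already initialized; world_size, batch_size, and device stand in for the values available inside train_one_epoch:

    import torch
    import torch.distributed as dist


    def reduce_tensor(tensor: torch.Tensor, n: int) -> torch.Tensor:
        # Sum the tensor across all ranks, then divide by n; n=1 gives a plain SUM.
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= n
        return rt


    def compute_dist_scale(batch_size: int, world_size: int, device: torch.device) -> torch.Tensor:
        # Total batch size across ranks; float32 (as in the patch) rather than the
        # default int64 tensor, keeping the reduce and the division below in floating point.
        global_batch_size = reduce_tensor(
            torch.tensor(batch_size, device=device, dtype=torch.float32),
            1,  # SUM, no averaging
        )
        # Equals 1.0 when every rank ran the same batch size; a rank whose final
        # batch is smaller gets a proportionally smaller scale so the all-reduced
        # gradient average is not biased toward the short batch.
        return world_size * batch_size / global_batch_size

With the hypothetical values world_size=4, a local batch of 96 on this rank, and 128 on the other three ranks, global_batch_size comes to 480 and the returned scale is 4 * 96 / 480 = 0.8, i.e. the short batch contributes less to the averaged gradient.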