From c7ac37693d5eca561b016db750ef0af95d672dbd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 4 Feb 2024 10:14:57 -0800 Subject: [PATCH] Add device arg to validate() calls in train.py --- train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train.py b/train.py index 5e027229..39e889ed 100755 --- a/train.py +++ b/train.py @@ -881,6 +881,7 @@ def main(): loader_eval, validate_loss_fn, args, + device=device, amp_autocast=amp_autocast, ) @@ -893,6 +894,7 @@ def main(): loader_eval, validate_loss_fn, args, + device=device, amp_autocast=amp_autocast, log_suffix=' (EMA)', )