Add device arg to validate() calls in train.py

This commit is contained in:
Ross Wightman 2024-02-04 10:14:57 -08:00
parent a08b57e801
commit c7ac37693d

View File

@ -881,6 +881,7 @@ def main():
loader_eval, loader_eval,
validate_loss_fn, validate_loss_fn,
args, args,
device=device,
amp_autocast=amp_autocast, amp_autocast=amp_autocast,
) )
@ -893,6 +894,7 @@ def main():
loader_eval, loader_eval,
validate_loss_fn, validate_loss_fn,
args, args,
device=device,
amp_autocast=amp_autocast, amp_autocast=amp_autocast,
log_suffix=' (EMA)', log_suffix=' (EMA)',
) )