Add device arg to validate() calls in train.py

This commit is contained in:
Ross Wightman 2024-02-04 10:14:57 -08:00
parent a08b57e801
commit c7ac37693d

View File

@ -881,6 +881,7 @@ def main():
loader_eval,
validate_loss_fn,
args,
device=device,
amp_autocast=amp_autocast,
)
@ -893,6 +894,7 @@ def main():
loader_eval,
validate_loss_fn,
args,
device=device,
amp_autocast=amp_autocast,
log_suffix=' (EMA)',
)