mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Minor updates
This commit is contained in:
parent
cf0c280e1b
commit
a336e5bff3
@ -1,3 +1,4 @@
|
|||||||
from .cosine_lr import CosineLRScheduler
|
from .cosine_lr import CosineLRScheduler
|
||||||
from .plateau_lr import PlateauLRScheduler
|
from .plateau_lr import PlateauLRScheduler
|
||||||
from .step_lr import StepLRScheduler
|
from .step_lr import StepLRScheduler
|
||||||
|
from .tanh_lr import TanhLRScheduler
|
5
train.py
5
train.py
@ -1,6 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import csv
|
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -218,7 +216,8 @@ def main():
|
|||||||
lr_scheduler.step(epoch, eval_metrics['eval_loss'])
|
lr_scheduler.step(epoch, eval_metrics['eval_loss'])
|
||||||
|
|
||||||
update_summary(
|
update_summary(
|
||||||
epoch, train_metrics, eval_metrics, output_dir, write_header=best_loss is None)
|
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
|
||||||
|
write_header=best_loss is None)
|
||||||
|
|
||||||
# save proper checkpoint with eval metric
|
# save proper checkpoint with eval metric
|
||||||
best_loss = saver.save_checkpoint({
|
best_loss = saver.save_checkpoint({
|
||||||
|
13
utils.py
13
utils.py
@ -5,6 +5,8 @@ import numpy as np
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import glob
|
import glob
|
||||||
|
import csv
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
class CheckpointSaver:
|
class CheckpointSaver:
|
||||||
@ -137,3 +139,14 @@ def get_outdir(path, *paths, inc=False):
|
|||||||
outdir = outdir_inc
|
outdir = outdir_inc
|
||||||
os.makedirs(outdir)
|
os.makedirs(outdir)
|
||||||
return outdir
|
return outdir
|
||||||
|
|
||||||
|
|
||||||
|
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
|
||||||
|
rowd = OrderedDict(epoch=epoch)
|
||||||
|
rowd.update(train_metrics)
|
||||||
|
rowd.update(eval_metrics)
|
||||||
|
with open(filename, mode='a') as cf:
|
||||||
|
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
||||||
|
if write_header: # first iteration (epoch == 1 can't be used)
|
||||||
|
dw.writeheader()
|
||||||
|
dw.writerow(rowd)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user