mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Minor updates
This commit is contained in:
parent
cf0c280e1b
commit
a336e5bff3
@ -1,3 +1,4 @@
|
||||
from .cosine_lr import CosineLRScheduler
|
||||
from .plateau_lr import PlateauLRScheduler
|
||||
from .step_lr import StepLRScheduler
|
||||
from .tanh_lr import TanhLRScheduler
|
5
train.py
5
train.py
@ -1,6 +1,4 @@
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
@ -218,7 +216,8 @@ def main():
|
||||
lr_scheduler.step(epoch, eval_metrics['eval_loss'])
|
||||
|
||||
update_summary(
|
||||
epoch, train_metrics, eval_metrics, output_dir, write_header=best_loss is None)
|
||||
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
|
||||
write_header=best_loss is None)
|
||||
|
||||
# save proper checkpoint with eval metric
|
||||
best_loss = saver.save_checkpoint({
|
||||
|
13
utils.py
13
utils.py
@ -5,6 +5,8 @@ import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import glob
|
||||
import csv
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class CheckpointSaver:
|
||||
@ -137,3 +139,14 @@ def get_outdir(path, *paths, inc=False):
|
||||
outdir = outdir_inc
|
||||
os.makedirs(outdir)
|
||||
return outdir
|
||||
|
||||
|
||||
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
|
||||
rowd = OrderedDict(epoch=epoch)
|
||||
rowd.update(train_metrics)
|
||||
rowd.update(eval_metrics)
|
||||
with open(filename, mode='a') as cf:
|
||||
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
||||
if write_header: # first iteration (epoch == 1 can't be used)
|
||||
dw.writeheader()
|
||||
dw.writerow(rowd)
|
||||
|
Loading…
x
Reference in New Issue
Block a user