mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
modify eval and vgg
This commit is contained in:
parent
a68d90c53b
commit
b06bd21e94
@ -106,7 +106,7 @@ class VGGNet(fluid.dygraph.Layer):
|
|||||||
x = self._conv_block_4(x)
|
x = self._conv_block_4(x)
|
||||||
x = self._conv_block_5(x)
|
x = self._conv_block_5(x)
|
||||||
|
|
||||||
x = fluid.layers.flatten(x, axis=0)
|
x = fluid.layers.reshape(x, [0,-1])
|
||||||
x = self._fc1(x)
|
x = self._fc1(x)
|
||||||
x = self._drop(x)
|
x = self._drop(x)
|
||||||
x = self._fc2(x)
|
x = self._fc2(x)
|
||||||
|
@ -19,13 +19,10 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import paddle.fluid as fluid
|
|
||||||
|
|
||||||
import program
|
|
||||||
|
|
||||||
from ppcls.data import Reader
|
from ppcls.data import Reader
|
||||||
from ppcls.utils.config import get_config
|
from ppcls.utils.config import get_config
|
||||||
from ppcls.utils.save_load import init_model
|
from ppcls.utils.save_load import init_model
|
||||||
|
from ppcls.utils import logger
|
||||||
|
|
||||||
from paddle.fluid.incubate.fleet.collective import fleet
|
from paddle.fluid.incubate.fleet.collective import fleet
|
||||||
from paddle.fluid.incubate.fleet.base import role_maker
|
from paddle.fluid.incubate.fleet.base import role_maker
|
||||||
@ -45,37 +42,25 @@ def parse_args():
|
|||||||
action='append',
|
action='append',
|
||||||
default=[],
|
default=[],
|
||||||
help='config options to be overridden')
|
help='config options to be overridden')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
|
# assign the place
|
||||||
fleet.init(role)
|
gpu_id = fluid.dygraph.parallel.Env().dev_id
|
||||||
|
|
||||||
config = get_config(args.config, overrides=args.override, show=True)
|
|
||||||
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
|
|
||||||
place = fluid.CUDAPlace(gpu_id)
|
place = fluid.CUDAPlace(gpu_id)
|
||||||
|
with fluid.dygraph.guard(place):
|
||||||
startup_prog = fluid.Program()
|
pre_weights_dict = fluid.dygraph.load_dygraph(config.pretrained_model)[0]
|
||||||
valid_prog = fluid.Program()
|
strategy = fluid.dygraph.parallel.prepare_context()
|
||||||
valid_dataloader, valid_fetchs = program.build(
|
net = program.create_model(config.ARCHITECTURE, config.classes_num)
|
||||||
config, valid_prog, startup_prog, is_train=False)
|
net = fluid.dygraph.parallel.DataParallel(net, strategy)
|
||||||
valid_prog = valid_prog.clone(for_test=True)
|
net.set_dict(pre_weights_dict)
|
||||||
|
valid_dataloader = program.create_dataloader()
|
||||||
exe = fluid.Executor(place)
|
valid_reader = Reader(config, 'valid')()
|
||||||
exe.run(startup_prog)
|
valid_dataloader.set_sample_list_generator(valid_reader, place)
|
||||||
|
net.eval()
|
||||||
init_model(config, valid_prog, exe)
|
top1_acc = program.run(valid_dataloader, config, net, None, 0, 'valid')
|
||||||
|
|
||||||
valid_reader = Reader(config, 'valid')()
|
|
||||||
valid_dataloader.set_sample_list_generator(valid_reader, place)
|
|
||||||
|
|
||||||
compiled_valid_prog = program.compile(config, valid_prog)
|
|
||||||
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
|
|
||||||
'eval')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user