mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Merge pull request #265 from littletomatodonkey/dyg/fix_pypath
remove python path config and support cpu train/val/infer
This commit is contained in:
commit
b0b9ca0d65
@ -12,9 +12,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from ppcls import model_zoo
|
from ppcls import model_zoo
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -13,19 +13,22 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
import program
|
||||||
|
from ppcls.utils import logger
|
||||||
|
from ppcls.utils.save_load import init_model
|
||||||
|
from ppcls.utils.config import get_config
|
||||||
|
from ppcls.data import Reader
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import paddle
|
||||||
|
import argparse
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import sys
|
||||||
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
from ppcls.data import Reader
|
sys.path.append(__dir__)
|
||||||
from ppcls.utils.config import get_config
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
from ppcls.utils.save_load import init_model
|
|
||||||
from ppcls.utils import logger
|
|
||||||
|
|
||||||
from paddle.fluid.incubate.fleet.collective import fleet
|
|
||||||
from paddle.fluid.incubate.fleet.base import role_maker
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -47,21 +50,26 @@ def parse_args():
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
# assign the place
|
config = get_config(args.config, overrides=args.override, show=True)
|
||||||
gpu_id = fluid.dygraph.parallel.Env().dev_id
|
# assign place
|
||||||
place = fluid.CUDAPlace(gpu_id)
|
use_gpu = config.get("use_gpu", True)
|
||||||
|
if use_gpu:
|
||||||
|
gpu_id = fluid.dygraph.ParallelEnv().dev_id
|
||||||
|
place = fluid.CUDAPlace(gpu_id)
|
||||||
|
else:
|
||||||
|
place = fluid.CPUPlace()
|
||||||
with fluid.dygraph.guard(place):
|
with fluid.dygraph.guard(place):
|
||||||
pre_weights_dict = fluid.dygraph.load_dygraph(config.pretrained_model)[0]
|
|
||||||
strategy = fluid.dygraph.parallel.prepare_context()
|
strategy = fluid.dygraph.parallel.prepare_context()
|
||||||
net = program.create_model(config.ARCHITECTURE, config.classes_num)
|
net = program.create_model(config.ARCHITECTURE, config.classes_num)
|
||||||
net = fluid.dygraph.parallel.DataParallel(net, strategy)
|
net = fluid.dygraph.parallel.DataParallel(net, strategy)
|
||||||
net.set_dict(pre_weights_dict)
|
init_model(config, net, optimizer=None)
|
||||||
valid_dataloader = program.create_dataloader()
|
valid_dataloader = program.create_dataloader()
|
||||||
valid_reader = Reader(config, 'valid')()
|
valid_reader = Reader(config, 'valid')()
|
||||||
valid_dataloader.set_sample_list_generator(valid_reader, place)
|
valid_dataloader.set_sample_list_generator(valid_reader, place)
|
||||||
net.eval()
|
net.eval()
|
||||||
top1_acc = program.run(valid_dataloader, config, net, None, 0, 'valid')
|
top1_acc = program.run(valid_dataloader, config, net, None, 0, 'valid')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
|
||||||
|
|
||||||
python -m paddle.distributed.launch \
|
python -m paddle.distributed.launch \
|
||||||
--selected_gpus="0" \
|
--selected_gpus="0" \
|
||||||
tools/eval.py \
|
tools/eval.py \
|
||||||
-c ./configs/eval.yaml
|
-c ./configs/eval.yaml \
|
||||||
|
-o load_static_weights=True \
|
||||||
|
-o use_gpu=False
|
||||||
|
@ -12,13 +12,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import utils
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import paddle.fluid as fluid
|
|
||||||
from ppcls.modeling import architectures
|
|
||||||
from ppcls.utils.save_load import load_dygraph_pretrain
|
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||||
|
from ppcls.modeling import architectures
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import utils
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -66,6 +70,23 @@ def postprocess(outputs, topk=5):
|
|||||||
return zip(index, prob[index])
|
return zip(index, prob[index])
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_list(img_file):
|
||||||
|
imgs_lists = []
|
||||||
|
if img_file is None or not os.path.exists(img_file):
|
||||||
|
raise Exception("not found any img file in {}".format(img_file))
|
||||||
|
|
||||||
|
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
|
||||||
|
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
|
||||||
|
imgs_lists.append(img_file)
|
||||||
|
elif os.path.isdir(img_file):
|
||||||
|
for single_file in os.listdir(img_file):
|
||||||
|
if single_file.split('.')[-1] in img_end:
|
||||||
|
imgs_lists.append(os.path.join(img_file, single_file))
|
||||||
|
if len(imgs_lists) == 0:
|
||||||
|
raise Exception("not found any img file in {}".format(img_file))
|
||||||
|
return imgs_lists
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
operators = create_operators()
|
operators = create_operators()
|
||||||
@ -78,22 +99,25 @@ def main():
|
|||||||
|
|
||||||
with fluid.dygraph.guard(place):
|
with fluid.dygraph.guard(place):
|
||||||
net = architectures.__dict__[args.model]()
|
net = architectures.__dict__[args.model]()
|
||||||
data = preprocess(args.image_file, operators)
|
|
||||||
data = np.expand_dims(data, axis=0)
|
|
||||||
data = fluid.dygraph.to_variable(data)
|
|
||||||
load_dygraph_pretrain(net, args.pretrained_model,
|
load_dygraph_pretrain(net, args.pretrained_model,
|
||||||
args.load_static_weights)
|
args.load_static_weights)
|
||||||
net.eval()
|
image_list = get_image_list(args.image_file)
|
||||||
outputs = net(data)
|
for idx, filename in enumerate(image_list):
|
||||||
outputs = fluid.layers.softmax(outputs)
|
data = preprocess(filename, operators)
|
||||||
outputs = outputs.numpy()
|
data = np.expand_dims(data, axis=0)
|
||||||
|
data = fluid.dygraph.to_variable(data)
|
||||||
|
net.eval()
|
||||||
|
outputs = net(data)
|
||||||
|
outputs = fluid.layers.softmax(outputs)
|
||||||
|
outputs = outputs.numpy()
|
||||||
|
|
||||||
probs = postprocess(outputs)
|
probs = postprocess(outputs)
|
||||||
rank = 1
|
rank = 1
|
||||||
for idx, prob in probs:
|
print("Current image file: {}".format(filename))
|
||||||
print("top{:d}, class id: {:d}, probability: {:.4f}".format(rank, idx,
|
for idx, prob in probs:
|
||||||
prob))
|
print("\ttop{:d}, class id: {:d}, probability: {:.4f}".format(
|
||||||
rank += 1
|
rank, idx, prob))
|
||||||
|
rank += 1
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,13 +15,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import utils
|
import utils
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from paddle.fluid.core import AnalysisConfig
|
from paddle.fluid.core import AnalysisConfig
|
||||||
from paddle.fluid.core import create_paddle_predictor
|
from paddle.fluid.core import create_paddle_predictor
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -101,7 +98,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
assert args.use_gpu is True
|
assert args.use_gpu is True
|
||||||
assert args.model_name is not None
|
assert args.model_name is not None
|
||||||
assert args.use_tensorrt is True
|
|
||||||
# HALF precission predict only work when using tensorrt
|
# HALF precission predict only work when using tensorrt
|
||||||
if args.use_fp16 is True:
|
if args.use_fp16 is True:
|
||||||
assert args.use_tensorrt is True
|
assert args.use_tensorrt is True
|
||||||
@ -130,8 +126,9 @@ def main():
|
|||||||
output = output.flatten()
|
output = output.flatten()
|
||||||
cls = np.argmax(output)
|
cls = np.argmax(output)
|
||||||
score = output[cls]
|
score = output[cls]
|
||||||
logger.info("class: {0}".format(cls))
|
print("Current image file: {}".format(args.image_file))
|
||||||
logger.info("score: {0}".format(score))
|
print("\ttop-1 class: {0}".format(cls))
|
||||||
|
print("\ttop-1 score: {0}".format(score))
|
||||||
else:
|
else:
|
||||||
for i in range(0, test_num + 10):
|
for i in range(0, test_num + 10):
|
||||||
inputs = np.random.rand(args.batch_size, 3, 224,
|
inputs = np.random.rand(args.batch_size, 3, 224,
|
||||||
@ -145,11 +142,13 @@ def main():
|
|||||||
output = output.flatten()
|
output = output.flatten()
|
||||||
if i >= 10:
|
if i >= 10:
|
||||||
test_time += time.time() - start_time
|
test_time += time.time() - start_time
|
||||||
|
time.sleep(0.01) # sleep for T4 GPU
|
||||||
|
|
||||||
fp_message = "FP16" if args.use_fp16 else "FP32"
|
fp_message = "FP16" if args.use_fp16 else "FP32"
|
||||||
logger.info("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
|
trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
|
||||||
args.model_name, fp_message, args.batch_size, 1000 * test_time /
|
print("{0}\t{1}\t{2}\tbatch size: {3}\ttime(ms): {4}".format(
|
||||||
test_num))
|
args.model_name, trt_msg, fp_message, args.batch_size, 1000 *
|
||||||
|
test_time / test_num))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
import utils
|
import utils
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -26,8 +27,6 @@ def parse_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("-i", "--image_file", type=str)
|
parser.add_argument("-i", "--image_file", type=str)
|
||||||
parser.add_argument("-d", "--model_dir", type=str)
|
parser.add_argument("-d", "--model_dir", type=str)
|
||||||
parser.add_argument("-m", "--model_file", type=str)
|
|
||||||
parser.add_argument("-p", "--params_file", type=str)
|
|
||||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
@ -41,10 +40,7 @@ def create_predictor(args):
|
|||||||
|
|
||||||
exe = fluid.Executor(place)
|
exe = fluid.Executor(place)
|
||||||
[program, feed_names, fetch_lists] = fluid.io.load_inference_model(
|
[program, feed_names, fetch_lists] = fluid.io.load_inference_model(
|
||||||
args.model_dir,
|
args.model_dir, exe, model_filename="model", params_filename="params")
|
||||||
exe,
|
|
||||||
model_filename=args.model_file,
|
|
||||||
params_filename=args.params_file)
|
|
||||||
compiled_program = fluid.compiler.CompiledProgram(program)
|
compiled_program = fluid.compiler.CompiledProgram(program)
|
||||||
|
|
||||||
return exe, compiled_program, feed_names, fetch_lists
|
return exe, compiled_program, feed_names, fetch_lists
|
||||||
@ -70,7 +66,6 @@ def preprocess(fname, ops):
|
|||||||
data = open(fname, 'rb').read()
|
data = open(fname, 'rb').read()
|
||||||
for op in ops:
|
for op in ops:
|
||||||
data = op(data)
|
data = op(data)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@ -81,21 +76,40 @@ def postprocess(outputs, topk=5):
|
|||||||
return zip(index, prob[index])
|
return zip(index, prob[index])
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_list(img_file):
|
||||||
|
imgs_lists = []
|
||||||
|
if img_file is None or not os.path.exists(img_file):
|
||||||
|
raise Exception("not found any img file in {}".format(img_file))
|
||||||
|
|
||||||
|
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
|
||||||
|
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
|
||||||
|
imgs_lists.append(img_file)
|
||||||
|
elif os.path.isdir(img_file):
|
||||||
|
for single_file in os.listdir(img_file):
|
||||||
|
if single_file.split('.')[-1] in img_end:
|
||||||
|
imgs_lists.append(os.path.join(img_file, single_file))
|
||||||
|
if len(imgs_lists) == 0:
|
||||||
|
raise Exception("not found any img file in {}".format(img_file))
|
||||||
|
return imgs_lists
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
operators = create_operators()
|
operators = create_operators()
|
||||||
exe, program, feed_names, fetch_lists = create_predictor(args)
|
exe, program, feed_names, fetch_lists = create_predictor(args)
|
||||||
|
|
||||||
data = preprocess(args.image_file, operators)
|
image_list = get_image_list(args.image_file)
|
||||||
data = np.expand_dims(data, axis=0)
|
for idx, filename in enumerate(image_list):
|
||||||
outputs = exe.run(program,
|
data = preprocess(filename, operators)
|
||||||
feed={feed_names[0]: data},
|
data = np.expand_dims(data, axis=0)
|
||||||
fetch_list=fetch_lists,
|
outputs = exe.run(program,
|
||||||
return_numpy=False)
|
feed={feed_names[0]: data},
|
||||||
probs = postprocess(outputs)
|
fetch_list=fetch_lists,
|
||||||
|
return_numpy=False)
|
||||||
for idx, prob in probs:
|
probs = postprocess(outputs)
|
||||||
print("class id: {:d}, probability: {:.4f}".format(idx, prob))
|
print("Current image file: {}".format(filename))
|
||||||
|
for idx, prob in probs:
|
||||||
|
print("\tclass id: {:d}, probability: {:.4f}".format(idx, prob))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
|
||||||
|
|
||||||
python -m paddle.distributed.launch \
|
python -m paddle.distributed.launch \
|
||||||
--selected_gpus="0,1,2,3" \
|
--selected_gpus="0,1,2,3" \
|
||||||
tools/train.py \
|
tools/train.py \
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
|
||||||
|
|
||||||
python tools/download.py -a ResNet34 -p ./pretrained/ -d 1
|
python tools/download.py -a ResNet34 -p ./pretrained/ -d 1
|
||||||
|
@ -13,19 +13,21 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
import program
|
||||||
|
from ppcls.utils import logger
|
||||||
|
from ppcls.utils.save_load import init_model, save_model
|
||||||
|
from ppcls.utils.config import get_config
|
||||||
|
from ppcls.data import Reader
|
||||||
|
import paddle.fluid as fluid
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import paddle.fluid as fluid
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append(__dir__)
|
||||||
from ppcls.data import Reader
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
from ppcls.utils.config import get_config
|
|
||||||
from ppcls.utils.save_load import init_model, save_model
|
|
||||||
from ppcls.utils import logger
|
|
||||||
import program
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -49,8 +51,12 @@ def parse_args():
|
|||||||
def main(args):
|
def main(args):
|
||||||
config = get_config(args.config, overrides=args.override, show=True)
|
config = get_config(args.config, overrides=args.override, show=True)
|
||||||
# assign the place
|
# assign the place
|
||||||
gpu_id = fluid.dygraph.parallel.Env().dev_id
|
use_gpu = config.get("use_gpu", True)
|
||||||
place = fluid.CUDAPlace(gpu_id)
|
if use_gpu:
|
||||||
|
gpu_id = fluid.dygraph.ParallelEnv().dev_id
|
||||||
|
place = fluid.CUDAPlace(gpu_id)
|
||||||
|
else:
|
||||||
|
place = fluid.CPUPlace()
|
||||||
|
|
||||||
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
|
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
|
||||||
config["use_data_parallel"] = use_data_parallel
|
config["use_data_parallel"] = use_data_parallel
|
||||||
|
Loading…
x
Reference in New Issue
Block a user