CxxPredictor

class CxxPredictor

CxxPredictor是Paddle-Lite的预测器,由create_paddle_predictor根据CxxConfig进行创建。用户可以根据CxxPredictor提供的接口设置输入数据、执行模型预测、获取输出以及获得当前使用lib的版本信息等。

示例:

from paddlelite.lite import *

# 1. 设置CxxConfig
config = CxxConfig()
if args.model_file != '' and args.param_file != '':
    config.set_model_file(args.model_file)
    config.set_param_file(args.param_file)
else:
    config.set_model_dir(args.model_dir)
places = [Place(TargetType.ARM, PrecisionType.FP32)]
config.set_valid_places(places)

# 2. 创建CxxPredictor
predictor = create_paddle_predictor(config)

# 3. 设置输入数据
input_tensor = predictor.get_input(0)
input_tensor.resize([1, 3, 224, 224])
input_tensor.set_float_data([1.] * 3 * 224 * 224)

# 4. 运行模型
predictor.run()

# 5. 获取输出数据
output_tensor = predictor.get_output(0)
print(output_tensor.shape())
print(output_tensor.float_data()[:10])

get_input(index)

获取输入Tensor,用来设置模型的输入数据。

参数:

  • index(int) - 输入Tensor的索引

返回:第index个输入Tensor

返回类型:Tensor

get_output(index)

获取输出Tensor,用来获取模型的输出结果。

参数:

  • index(int) - 输出Tensor的索引

返回:第index个输出Tensor

返回类型:Tensor

run()

执行模型预测,需要在***设置输入数据后***调用。

参数:

  • None

返回:None

返回类型:None

get_version()

用于获取当前lib使用的代码版本。若代码有相应tag则返回tag信息,如v2.0-beta;否则返回代码的branch(commitid),如develop(7e44619)

参数:

  • None

返回:当前lib使用的代码版本信息

返回类型:str