20

【推理引擎】如何在 ONNXRuntime 中添加新的算子 - 虔诚的树

 3 years ago
source link: https://www.cnblogs.com/xxxxxxxxx/p/16078427.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

【推理引擎】如何在 ONNXRuntime 中添加新的算子

如果模型中有些算子不被ONNX算子库支持,我们就需要利用ONNXRuntime提供的API手动添加新算子。在官方文档中已经对如何添加定制算子进行了介绍(https://onnxruntime.ai/docs/reference/operators/add-custom-op.html ),这里我们主要把源码中对应的流程给捋清楚。

添加定制算子(Custom Operators)主要分为三步:

  1. 创建一个定制算子域(CusttomOpDomain);
  2. 创建一个定制算子(CustomOp),并将该算子添加到定制算子域中;
  3. 将定制算子域添加到 SessionOption 中

首先看看源码中给出的定制算子样例:

// file path: onnxruntime/test/shared_lib/custom_op_utils.h

// 首先定义定制算子的核
struct MyCustomKernel {
  MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/, void* compute_stream)
      : ort_(ort), compute_stream_(compute_stream) {
  }

  void Compute(OrtKernelContext* context);

 private:
  Ort::CustomOpApi ort_;
  void* compute_stream_;
};

// 然后定义定制算子的各个操作,各个成员函数均已实现,其中 CreateKernel 会返回前面定义的算子核对象
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
  explicit MyCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {}

  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info, compute_stream_); };
  const char* GetName() const { return "Foo"; };
  const char* GetExecutionProviderType() const { return provider_; };

  size_t GetInputTypeCount() const { return 2; };
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    // Both the inputs need to be necessarily of float type
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };

  size_t GetOutputTypeCount() const { return 1; };
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };

 private:
  const char* provider_;
  void* compute_stream_;
};

在上面代码中,我们看到定制算子继承自 Ort::CustomOpBase<MyCustomOp, MyCustomKernel>,这种扩展类作为模板基类的模板参数的方式又被称为CRTP,接着深入到这个模板类内部:

// file path: include/onnxruntime/core/session/onnxruntime_cxx_api.h

template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
  CustomOpBase() {
    OrtCustomOp::version = ORT_API_VERSION;
    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
    OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };

    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
    OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
    OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
    OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };

    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
  }

  // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
  const char* GetExecutionProviderType() const { return nullptr; }

  // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
  // (inputs and outputs are required by default)
  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }

  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }
};

这里的 CustomOpBase 又继承自 OrtCustomOp

// include/onnxruntime/core/session/onnxruntime_c_api.h

struct OrtCustomOp;
typedef struct OrtCustomOp OrtCustomOp;

struct OrtCustomOp {
  uint32_t version;  // Must be initialized to ORT_API_VERSION

  // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below.
  void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api,
                                    _In_ const OrtKernelInfo* info);

  // Returns the name of the op
  const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op);

  // Returns the type of the execution provider, return nullptr to use CPU execution provider
  const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op);

  // Returns the count and types of the input & output tensors
  ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
  size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op);
  ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
  size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op);

  // Op kernel callbacks
  void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
  void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel);

  // Returns the characteristics of the input & output tensors
  OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
  OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
};

可以发现,OrtCustomOp 中定义了定制算子应该实现的模式,其中的一系列回调函数由其派生类一一实现,比如上文提到的 CustomOpBase 在其构造函数中,以 lambda 函数的方式实现各个回调函数。

至此,我们已经完整地梳理了定义定制算子在源码内部是如何实现的,接下来介绍如何将定义好的定制算子使用起来。

从如下官方测试代码开始分析:

// file path: onnxruntime/test/shared_lib/test_inference.cc

TEST(CApiTest, custom_op_handler) {
  std::cout << "Running custom op inference" << std::endl;

  std::vector<Input> inputs(1);
  Input& input = inputs[0];
  input.name = "X";
  input.dims = {3, 2};
  input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // prepare expected inputs and outputs
  std::vector<int64_t> expected_dims_y = {3, 2};
  std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};

  // 创建定制算子(MyCustomOp)
#ifdef USE_CUDA
  cudaStream_t compute_stream = nullptr;    // 声明一个 cuda stream
  cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);  // 创建一个 cuda stream
  MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider, compute_stream};
#else
  MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider, nullptr};
#endif
  
  // 创建定制算子域(CustomOpDomain)
  Ort::CustomOpDomain custom_op_domain("");
  // 在定制算子域中添加定制算子
  custom_op_domain.Add(&custom_op);

  // 进入 TestInference
#ifdef USE_CUDA
  TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1,
                       custom_op_domain, nullptr, nullptr, false, compute_stream);
  cudaStreamDestroy(compute_stream);
#else
  TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0,
                       custom_op_domain, nullptr);
#endif
}

以上代码需要特别注意的是,需要根据宏(USE_CUDA)用来判断是否使用CUDA。如果使用 CUDA:

  • 当模型运行在GPU上,而插入的是 CPU 定制算子,那么 ONNXRuntime 会在 CPU 定制算子前后分别插入两个操作 MemcpyToHost、MemcpyFromHost,这两个操作负责内存拷贝,即首先从 Device 拷贝到 Host,再从 Host 拷贝到 Device;
  • 如果插入的是 GPU 定制算子,为了确保 ORT 的 CUDA kernels 和定制 CUDA kernels 之间的同步,它们必须使用同一个 CUDA 计算流。具体细节在下一个代码继续分析。

这里创建 cuda stream 的方式是 cudaStreamCreateWithFlags,该函数和 cudaStreamCreate 不同,后者在多次调用时是串行方式执行,而前者可同步执行。如果将参数 cudaStreamNonBlocking 替换为 cudaStreamDefault,则 cudaStreamCreateWithFlags 的行为将和 cudaStreamCreate 相同。【参考内容:CUDA 5.0 中cudaStreamCreateWithFlags 的用法

无论是否使用CDUA,我们都需要创建定制算子(MyCustomOp)。

进入 TestInference 函数内部:

// file path: onnxruntime/test/shared_lib/test_inference.cc

template <typename OutT>
static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri,
                          const std::vector<Input>& inputs,
                          const char* output_name,
                          const std::vector<int64_t>& expected_dims_y,
                          const std::vector<OutT>& expected_values_y,
                          int provider_type,
                          OrtCustomOpDomain* custom_op_domain_ptr,
                          const char* custom_op_library_filename,
                          void** library_handle = nullptr,
                          bool test_session_creation_only = false,
                          void* cuda_compute_stream = nullptr) {
  Ort::SessionOptions session_options;

  if (provider_type == 1) {
#ifdef USE_CUDA
    std::cout << "Running simple inference with cuda provider" << std::endl;
    auto cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream);
    session_options.AppendExecutionProvider_CUDA(cuda_options);
#else
    ORT_UNUSED_PARAMETER(cuda_compute_stream);
    return;
#endif
  } else if (provider_type == 2) {
#ifdef USE_DNNL
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Dnnl(session_options, 1));
    std::cout << "Running simple inference with dnnl provider" << std::endl;
#else
    return;
#endif
  } else if (provider_type == 3) {
#ifdef USE_NUPHAR
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options,
                                                                      /*allow_unaligned_buffers*/ 1, ""));
    std::cout << "Running simple inference with nuphar provider" << std::endl;
#else
    return;
#endif
  } else {
    std::cout << "Running simple inference with default provider" << std::endl;
  }
  if (custom_op_domain_ptr) {
    session_options.Add(custom_op_domain_ptr);
  }

  if (custom_op_library_filename) {
    Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary(session_options,
                                                             custom_op_library_filename, library_handle));
  }

  // if session creation passes, model loads fine
  Ort::Session session(env, model_uri.c_str(), session_options);

  // caller wants to test running the model (not just loading the model)
  if (!test_session_creation_only) {
    // Now run
    auto default_allocator = std::make_unique<MockedOrtAllocator>();

    //without preallocated output tensor
    RunSession<OutT>(default_allocator.get(),
                     session,
                     inputs,
                     output_name,
                     expected_dims_y,
                     expected_values_y,
                     nullptr);
    //with preallocated output tensor
    Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(),
                                                         expected_dims_y.data(), expected_dims_y.size());

    //test it twice
    for (int i = 0; i != 2; ++i)
      RunSession<OutT>(default_allocator.get(),
                       session,
                       inputs,
                       output_name,
                       expected_dims_y,
                       expected_values_y,
                       &value_y);
  }
}

前文提到,如果对应EP是CUDA,需要确保 ORT 的 CUDA kernels 和定制 CUDA kernels 之间的同步。为了实现这一目标,首先通过 CreateDefaultOrtCudaProviderOptionsWithCustomStream 函数将新创建的 CUDA 计算流以 OrtCudaProviderOptions 的形式传递给 SessionOptions:

OrtCUDAProviderOptions cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream);
session_options.AppendExecutionProvider_CUDA(cuda_options)

之后,将定制算子域也添加到 SessionOptions 中:

if (custom_op_domain_ptr) {
  session_options.Add(custom_op_domain_ptr);
}

至此,SessionOptions 已经构建完成,下面创建 Session 并通过 model_uri 加载模型:

Ort::Session session(env, model_uri.c_str(), session_options);

这里的(1)Ort::Session 是在 onnxruntime_cxx_api.h 文件中声明的类,(2)对应的构造函数在 onnxruntime_cxx_inline.h 中实现,(3)实现方式是进一步调用 onnxruntime_c_api.h 中定义的 API,该 API 也仅仅是声明,(4)最终对应的实现在 onnxruntime_c_api.cc 文件中:

// (1) include/onnxruntime/core/session/onnxruntime_cxx_api.h
struct Session : Base<OrtSession> {
  explicit Session(std::nullptr_t) {}
  Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
}

// (2) include/onnxruntime/core/session/onnxruntime_cxx_inline.h
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
  ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_));
}

// (3) include/onnxruntime/core/session/onnxruntime_c_api.h
ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
                _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out);

// (4) onnxruntime/core/session/onnxruntime_c_api.cc
ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
                    _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) {
  API_IMPL_BEGIN
  std::unique_ptr<onnxruntime::InferenceSession> sess;
  OrtStatus* status = nullptr;
  *out = nullptr;

  ORT_TRY {
    ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess));
    ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess));

    *out = reinterpret_cast<OrtSession*>(sess.release());
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      status = OrtApis::CreateStatus(ORT_FAIL, e.what());
    });
  }

  return status;
  API_IMPL_END
}

可以发现,Ort::Session 内部还是调用了 onnxruntime::InferenceSession

扯远了,下面回归主题。

创建 Session 完成之后,便开始运行,进入 RunSession 函数内部:

// file path: onnxruntime/test/shared_lib/test_inference.cc

template <typename OutT>
void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
                const std::vector<Input>& inputs,
                const char* output_name,
                const std::vector<int64_t>& dims_y,
                const std::vector<OutT>& values_y,
                Ort::Value* output_tensor) {
  
  // 构建模型输入
  std::vector<Ort::Value> ort_inputs;
  std::vector<const char*> input_names;
  for (size_t i = 0; i < inputs.size(); i++) {
    input_names.emplace_back(inputs[i].name);
    ort_inputs.emplace_back(
        Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()),
                                        inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
  }
  
  // 运行 RUN
  std::vector<Ort::Value> ort_outputs;
  if (output_tensor)
    session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                       &output_name, output_tensor, 1);
  else {
    ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                                     &output_name, 1);
    ASSERT_EQ(ort_outputs.size(), 1u);
    output_tensor = &ort_outputs[0];
  }

  auto type_info = output_tensor->GetTensorTypeAndShapeInfo();
  ASSERT_EQ(type_info.GetShape(), dims_y);
  size_t total_len = type_info.GetElementCount();
  ASSERT_EQ(values_y.size(), total_len);

  OutT* f = output_tensor->GetTensorMutableData<OutT>();
  for (size_t i = 0; i != total_len; ++i) {
    ASSERT_EQ(values_y[i], f[i]);
  }
}

这里使用了一些GTest中的断言来判定运行结果是否符合预期。

至此,我们已经完整地分析了定制算子从定义到使用的全部流程。

文档中还提到了 Contrib ops,这类算子归属于 contrib ops domain,是嵌入到 runtime 内部的,对于一些使用低频的算子最好不要加入这个域中,否则会导致运行时库(runtime library)过大。
官方文档中给出了添加算子到这个域中的方法,这里就不再进行介绍了,以后用到了再说吧。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK