dan-masek

Untitled

Mar 31st, 2020
515
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.12 KB | None | 0 0
  #include <pybind11/pybind11.h>
  #include <pybind11/embed.h>
  #include <pybind11/numpy.h>
  #include <pybind11/stl.h>

  #include <opencv2/opencv.hpp>

  #include <exception>
  #include <iostream>
  #include <memory>
  #include <stdexcept>
  #include <vector>

  namespace py = pybind11;
  11.  
  12. // ============================================================================
  13. int determine_cv_type(pybind11::dtype const& dtype)
  14. {
  15.     switch (dtype.kind()) {
  16.     case 'u':
  17.         switch (dtype.itemsize()) {
  18.         case 1: return CV_8U;
  19.         case 2: return CV_16U;
  20.         default:
  21.             throw std::invalid_argument("Unsupported unsigned integer size.");
  22.         }
  23.     case 'i':
  24.         switch (dtype.itemsize()) {
  25.         case 1: return CV_8S;
  26.         case 2: return CV_16S;
  27.         case 4: return CV_32S;
  28.         default:
  29.             throw std::invalid_argument("Unsupported signed integer size.");
  30.         }
  31.     case 'f':
  32.         switch (dtype.itemsize()) {
  33.         case 4: return CV_32F;
  34.         case 8: return CV_64F;
  35.         default:
  36.             throw std::invalid_argument("Unsupported float size.");
  37.         }
  38.     default:
  39.         throw std::invalid_argument("Unsupported array dtype.");
  40.     }
  41. }
  42. // ----------------------------------------------------------------------------
  43. cv::Mat nparray_to_mat(py::array& arr)
  44. {
  45.     auto const depth = determine_cv_type(arr.dtype());
  46.  
  47.     switch (arr.ndim()) {
  48.     case 1:
  49.         return cv::Mat(
  50.             static_cast<int>(arr.shape(0)) // Rows
  51.             , 1 // Columns
  52.             , CV_MAKETYPE(depth, 1) // Data type
  53.             , arr.mutable_data()
  54.         );
  55.     case 2:
  56.         return cv::Mat(
  57.             static_cast<int>(arr.shape(0)) // Rows
  58.             , static_cast<int>(arr.shape(1)) // Columns
  59.             , CV_MAKETYPE(depth, 1) // Data type
  60.             , arr.mutable_data()
  61.         );
  62.     case 3:
  63.         if (arr.shape(2) > CV_CN_MAX) {
  64.             std::invalid_argument("Too many array channels.");
  65.         }
  66.         return cv::Mat(
  67.             static_cast<int>(arr.shape(0)) // Rows
  68.             , static_cast<int>(arr.shape(1)) // Columns
  69.             , CV_MAKETYPE(depth, static_cast<int>(arr.shape(2))) // Data type
  70.             , arr.mutable_data()
  71.         );
  72.     default:
  73.         throw std::invalid_argument("Only 1D, 2D and 3D arrays supported.");
  74.     }
  75. }
  76. // ============================================================================
  77. py::dtype determine_np_dtype(int depth)
  78. {
  79.     switch (depth) {
  80.     case CV_8U: return py::dtype::of<uint8_t>();
  81.     case CV_8S: return py::dtype::of<int8_t>();
  82.     case CV_16U: return py::dtype::of<uint16_t>();
  83.     case CV_16S: return py::dtype::of<int16_t>();
  84.     case CV_32S: return py::dtype::of<int32_t>();
  85.     case CV_32F: return py::dtype::of<float>();
  86.     case CV_64F: return py::dtype::of<double>();
  87.     default:
  88.         throw std::invalid_argument("Unsupported data type.");
  89.     }
  90. }
  91. // ----------------------------------------------------------------------------
  92. py::capsule make_capsule(cv::Mat& m)
  93. {
  94.     return py::capsule(new cv::Mat(m)
  95.         , [](void *v) { delete reinterpret_cast<cv::Mat*>(v); }
  96.         );
  97. }
  98. // ----------------------------------------------------------------------------
  99. py::array mat_to_nparray(cv::Mat& m)
  100. {
  101.     if (!m.isContinuous()) {
  102.         throw std::invalid_argument("Only continuous Mats supported.");
  103.     }
  104.  
  105.     std::vector<std::size_t> shape;
  106.    
  107.     if (m.channels() == 1) {
  108.         shape = {
  109.             static_cast<size_t>(m.rows)
  110.             , static_cast<size_t>(m.cols)
  111.         };
  112.     } else {
  113.         shape = {
  114.             static_cast<size_t>(m.rows)
  115.             , static_cast<size_t>(m.cols)
  116.             , static_cast<size_t>(m.channels())
  117.         };
  118.     }
  119.  
  120.     return py::array(determine_np_dtype(m.depth())
  121.         , shape
  122.         , m.data
  123.         , make_capsule(m));
  124. }
  125. // ============================================================================
  126. PYBIND11_EMBEDDED_MODULE(test_module, m)
  127. {
  128.     m.doc() = "Test module";
  129. }
  130. // ============================================================================
  131. int main()
  132. {
  133.     // Start the interpreter and keep it alive
  134.     py::scoped_interpreter guard{};
  135.  
  136.     try {
  137.         auto locals = py::dict{};
  138.  
  139.         py::exec(R"(
  140.            import numpy as np
  141.  
  142.            def test_cpp_to_py(arr):
  143.                return (1,2,3)
  144.        )");
  145.  
  146.         auto test_cpp_to_py = py::globals()["test_cpp_to_py"];
  147.  
  148.  
  149.         for (int i = 0; i < 10; i++) {
  150.             int64 t0 = cv::getTickCount();
  151.  
  152.             cv::Mat img(cv::Mat::zeros(1024, 1024, CV_8UC3) + cv::Scalar(1, 1, 1));
  153.  
  154.             int64 t1 = cv::getTickCount();
  155.  
  156.             auto result = test_cpp_to_py(mat_to_nparray(img));
  157.  
  158.             int64 t2 = cv::getTickCount();
  159.  
  160.             double delta0 = (t1 - t0) / cv::getTickFrequency() * 1000;
  161.             double delta1 = (t2 - t1) / cv::getTickFrequency() * 1000;
  162.  
  163.             std::cout << "* " << delta0 << " ms | " << delta1 << " ms" << std::endl;
  164.         }        
  165.     } catch (py::error_already_set& e) {
  166.         std::cerr << e.what() << "\n";
  167.     }
  168.    
  169.     return 0;
  170. }
  171. // ============================================================================
Add Comment
Please, Sign In to add comment