Advertisement
here2share

pybind11_tester

Apr 5th, 2021
1,173
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.52 KB | None | 0 0
  1. ##### CMakeLists.txt
  2. cmake_minimum_required(VERSION 3.4...3.18)
  3. project(pybindtest)
  4. add_subdirectory(pybind11)
  5. pybind11_add_module(module_name main.cpp)
  6.  
  7.  
  8.  
  9. ##### main.cpp
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>
  16.  
  17. namespace py = pybind11;
  18.  
  19. float some_fn(float arg1, float arg2) {
  20.   return arg1 + arg2;
  21. }
  22.  
  23. class SomeClass {
  24.   float multiplier;
  25. public:
  26.   SomeClass(float multiplier_) : multiplier(multiplier_) {};
  27.  
  28.   float multiply(float input) {
  29.     return multiplier * input;
  30.   }
  31.  
  32.   std::vector<float> multiply_list(std::vector<float> items) {
  33.     for (auto i = 0; i < items.size(); i++) {
  34.       items[i] = multiply(items.at(i));
  35.     }
  36.     return items;
  37.   }
  38.  
  39.   // py::tuple multiply_two(float one, float two) {
  40.   //   return py::make_tuple(multiply(one), multiply(two));
  41.   // }
  42.  
  43.   std::vector<std::vector<uint8_t>> make_image() {
  44.     auto out = std::vector<std::vector<uint8_t>>();
  45.     for (auto i = 0; i < 128; i++) {
  46.       out.push_back(std::vector<uint8_t>(64));
  47.     }
  48.     for (auto i = 0; i < 30; i++) {
  49.       for (auto j = 0; j < 30; j++) { out[i][j] = 255; }
  50.     }
  51.     return out;
  52.   }
  53.  
  54.   void set_mult(float val) {
  55.     multiplier = val;
  56.   }
  57.  
  58.   float get_mult() {
  59.     return multiplier;
  60.   }
  61.  
  62.   void function_that_takes_a_while() {
  63.     py::gil_scoped_release release;
  64.     std::cout << "starting" << std::endl;
  65.     std::this_thread::sleep_for(std::chrono::milliseconds(2000));
  66.     std::cout << "ended" << std::endl;
  67.  
  68.     // py::gil_scoped_acquire acquire;
  69.     // auto list = py::list();
  70.     // list.append(1);
  71.   }
  72. };
  73.  
  74. SomeClass some_class_factory(float multiplier) {
  75.   return SomeClass(multiplier);
  76. }
  77.  
  78.  
  79. PYBIND11_MODULE(module_name, module_handle) {
  80.   module_handle.doc() = "I'm a docstring hehe";
  81.   module_handle.def("some_fn_python_name", &some_fn);
  82.   module_handle.def("some_class_factory", &some_class_factory);
  83.   py::class_<SomeClass>(
  84.             module_handle, "PySomeClass"
  85.             ).def(py::init<float>())
  86.     .def_property("multiplier", &SomeClass::get_mult, &SomeClass::set_mult)
  87.     .def("multiply", &SomeClass::multiply)
  88.     .def("multiply_list", &SomeClass::multiply_list)
  89.     // .def_property_readonly("image", &SomeClass::make_image)
  90.     .def_property_readonly("image", [](SomeClass &self) {
  91.                       py::array out = py::cast(self.make_image());
  92.                       return out;
  93.                     })
  94.     // .def("multiply_two", &SomeClass::multiply_two)
  95.     .def("multiply_two", [](SomeClass &self, float one, float two) {
  96.                return py::make_tuple(self.multiply(one), self.multiply(two));
  97.              })
  98.     .def("function_that_takes_a_while", &SomeClass::function_that_takes_a_while)
  99.     ;
  100. }
  101.  
  102.  
  103.  
  104. ##### test.py
  105. import time
  106. import traceback
  107. import cv2
  108. from build.module_name import *
  109.  
  110. from concurrent.futures import ThreadPoolExecutor
  111.  
  112. def call_and_print_exc(fn):
  113.     try:
  114.         fn()
  115.     except Exception:
  116.         traceback.print_exc()
  117.  
  118. print(PySomeClass)
  119.  
  120.  
  121. m = some_class_factory(10)
  122.  
  123. m2 = PySomeClass(10)
  124.  
  125. print(m, m2)
  126.  
  127. print(m.multiply(20))
  128.  
  129. # print(m.multiply("20"))
  130.  
  131. arr = m.multiply_list([0.0, 1.0, 2.0, 3.0])
  132.  
  133. print(arr)
  134.  
  135. print(m.multiply_two(50, 200))
  136.  
  137. print(m.image)
  138.  
  139. print(m.image.shape)
  140.  
  141. cv2.imwrite("/tmp/test.png", m.image)
  142.  
  143. print(m.multiplier)
  144.  
  145. m.multiplier = 100
  146.  
  147. print(m.multiplier)
  148.  
  149. start = time.time()
  150.  
  151. with ThreadPoolExecutor(4) as ex:
  152.     ex.map(lambda x: m.function_that_takes_a_while(), [None]*4)
  153.  
  154. print(f"Threaded fun took {time.time() - start} seconds")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement