Advertisement
alkkofficial

dummy executor async

Jun 7th, 2024 (edited)
877
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 3.09 KB | None | 0 0
  1. use flume::Sender;
  2. use futures::executor::ThreadPool;
  3. use rand::Rng;
  4. use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  5. use std::{collections::VecDeque, env, thread};
  6. use tch::Tensor;
  7. use tz_rust::{decoder::eval_state, executor::ExecutorDebugger, mcts_trainer::Net};
  8.  
  9. fn main() {
  10.     env::set_var("RUST_BACKTRACE", "1");
  11.     let pool = ThreadPool::new().expect("Failed to build pool");
  12.  
  13.     let batch_size: usize = 256;
  14.     let num_executors: usize = 2;
  15.     let num_generators: usize = batch_size * num_executors * 4;
  16.  
  17.     crossbeam::scope(|s| {
  18.         let (tensor_sender, tensor_receiver) = flume::bounded::<Tensor>(num_generators);
  19.  
  20.         for _ in 0..num_generators {
  21.             let tensor_sender_clone = tensor_sender.clone();
  22.             let fut_generator = async move { dummy_generator(tensor_sender_clone).await };
  23.             pool.spawn_ok(fut_generator);
  24.         }
  25.  
  26.         for i in 0..num_executors {
  27.             let tensor_receiver_clone = tensor_receiver.clone();
  28.             s.builder()
  29.                 .name(format!("executor-{}", i))
  30.                 .spawn(move |_| {
  31.                     let net = Net::new("tz_6515.pt");
  32.  
  33.                     let thread_name = thread::current()
  34.                         .name()
  35.                         .unwrap_or("unnamed-executor")
  36.                         .to_owned();
  37.  
  38.                     let mut input_vec: VecDeque<Tensor> = VecDeque::new();
  39.                     let mut one_sec_timer = Instant::now();
  40.                     let mut eval_counter = 0;
  41.  
  42.                     let mut debugger = ExecutorDebugger::create_debug();
  43.  
  44.                     loop {
  45.                         let data = tensor_receiver_clone.recv().unwrap();
  46.                         input_vec.push_back(data);
  47.                         if input_vec.len() == batch_size {
  48.                             debugger.record("waiting_for_batch", &thread_name);
  49.  
  50.                             let input_tensors = Tensor::cat(&input_vec.make_contiguous(), 0);
  51.  
  52.                             let eval_debugger = ExecutorDebugger::create_debug();
  53.                             let _ = eval_state(input_tensors, &net).expect("Error");
  54.                             eval_debugger.record("evaluation_time_taken", &thread_name);
  55.  
  56.                             input_vec.clear();
  57.  
  58.                             eval_counter += 1;
  59.                             debugger.reset();
  60.                         }
  61.                         if one_sec_timer.elapsed() > Duration::from_secs(1) {
  62.                             eval_counter = 0;
  63.                             one_sec_timer = Instant::now();
  64.                         }
  65.                     }
  66.                 })
  67.                 .unwrap();
  68.         }
  69.     })
  70.     .unwrap();
  71. }
  72.  
  73. fn random_tensor(size: usize) -> Tensor {
  74.     let mut rng = rand::thread_rng();
  75.     let data: Vec<f32> = (0..size).map(|_| rng.gen::<f32>()).collect();
  76.     Tensor::from_slice(&data)
  77. }
  78.  
  79. async fn dummy_generator(tensor_sender_clone: Sender<Tensor>) {
  80.     loop {
  81.         let data = random_tensor(1344 * 1);
  82.         tensor_sender_clone.send_async(data).await.unwrap();
  83.     }
  84. }
  85.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement