Advertisement
alkkofficial

dummy executor minimum

Jun 7th, 2024
925
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 3.66 KB | None | 0 0
  1. use flume::{Receiver, Sender};
  2. use rand::Rng;
  3. use std::{
  4.     collections::VecDeque,
  5.     env,
  6.     time::{Duration, Instant},
  7. };
  8. use tch::Tensor;
  9. use tz_rust::{decoder::eval_state, mcts_trainer::Net};
  10.  
  11. fn main() {
  12.     env::set_var("RUST_BACKTRACE", "1");
  13.     const NUM_LOOPS: usize = 100;
  14.     const NUM_WARMUPS: usize = 100;
  15.     const BATCH_SIZE: usize = 512;
  16.     const NUM_EXECUTORS: usize = 2;
  17.     const NUM_GENERATORS: usize = 1024;
  18.  
  19.     let entire_benchmark_timer = Instant::now();
  20.  
  21.     crossbeam::scope(|s| {
  22.         // send/recv pairs between executors and generators
  23.         let (tensor_sender, tensor_receiver) = flume::bounded::<Tensor>(NUM_GENERATORS); // dummy generator to executor
  24.  
  25.         // spawn the dummy generators
  26.         for i in 0..NUM_GENERATORS {
  27.             let tensor_sender_clone = tensor_sender.clone();
  28.             s.builder()
  29.                 .name(format!("thread-{}", i + 1))
  30.                 .spawn(move |_| loop {
  31.                     let data = random_tensor(1344 * 1); // sending dummy data one at a time, 1d (for 1 dummy position)
  32.                     tensor_sender_clone.send(data).unwrap();
  33.                 })
  34.                 .unwrap();
  35.         }
  36.  
  37.         // spawn the executors
  38.         for i in 0..NUM_EXECUTORS {
  39.             let tensor_receiver_clone = tensor_receiver.clone();
  40.             s.builder()
  41.                 .name(format!("thread-{}", i + 1))
  42.                 .spawn(move |_| {
  43.                     let net = Net::new("tz_6515.pt"); // Create a new instance of Net within the thread
  44.  
  45.                     // Warmup loop
  46.                     for _ in 0..NUM_WARMUPS {
  47.                         let data = random_tensor(1344 * BATCH_SIZE); // 8*8*21 = 1344
  48.                         let _ = eval_state(data, &net).expect("Error");
  49.                     }
  50.  
  51.                     let thread_name = std::thread::current()
  52.                         .name()
  53.                         .unwrap_or("unnamed-executor")
  54.                         .to_owned();
  55.  
  56.                     // Timed, benchmarked loop
  57.                     let full_run = Instant::now();
  58.                     let mut input_vec: VecDeque<Tensor> = VecDeque::new();
  59.                     let mut one_sec_timer = Instant::now();
  60.                     let mut eval_counter = 0; // keep track of the number of foward passes in a second
  61.                     loop {
  62.                         let data = tensor_receiver_clone.recv().unwrap();
  63.                         input_vec.push_back(data);
  64.                         if input_vec.len() == BATCH_SIZE {
  65.                             let i_v = input_vec.make_contiguous();
  66.                             let input_tensors = Tensor::cat(&i_v, 0);
  67.                             let _ = eval_state(input_tensors, &net).expect("Error");
  68.                             input_vec.clear();
  69.                             // calculate and display evals/s
  70.                             eval_counter += 1;
  71.                         }
  72.                         if one_sec_timer.elapsed() > Duration::from_secs(1) {
  73.                             println!("{} {}evals/s", thread_name, BATCH_SIZE * eval_counter);
  74.                             eval_counter = 0;
  75.                             one_sec_timer = Instant::now();
  76.                         }
  77.                     }
  78.                 })
  79.                 .unwrap();
  80.         }
  81.     })
  82.     .unwrap();
  83.  
  84.     let total_time_secs = entire_benchmark_timer.elapsed().as_nanos() as f32 / 1e9;
  85.     println!("Benchmark ran for {}s", total_time_secs);
  86. }
  87.  
  88. fn random_tensor(size: usize) -> Tensor {
  89.     let mut rng = rand::thread_rng();
  90.     let data: Vec<f32> = (0..size).map(|_| rng.gen::<f32>()).collect();
  91.     Tensor::from_slice(&data)
  92. }
  93.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement