use rand::Rng;
use std::{
    collections::VecDeque,
    env,
    time::{Duration, Instant},
};
use tch::Tensor;
use tz_rust::{decoder::eval_state, mcts_trainer::Net};
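// Throughput benchmark: NUM_GENERATORS producer threads push single dummy
// positions (flat 1344-element tensors) into a bounded channel, while
// NUM_EXECUTORS consumer threads batch them up to BATCH_SIZE positions,
// run a network forward pass via eval_state, and print evaluations per second.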
fn main() {
    env::set_var("RUST_BACKTRACE", "1");
    const NUM_LOOPS: usize = 100; // currently unused
    const NUM_WARMUPS: usize = 100; // un-timed warmup batches per executor
    const BATCH_SIZE: usize = 512; // positions per forward pass
    const NUM_EXECUTORS: usize = 2; // executor threads running the network
    const NUM_GENERATORS: usize = 1024; // generator threads producing dummy positions
    let entire_benchmark_timer = Instant::now();
    crossbeam::scope(|s| {
        // send/recv pair between the dummy generators and the executors
        let (tensor_sender, tensor_receiver) = flume::bounded::<Tensor>(NUM_GENERATORS);
        // spawn the dummy generators; `send` blocks once the bounded channel is
        // full, so they only produce as fast as the executors consume
        for i in 0..NUM_GENERATORS {
            let tensor_sender_clone = tensor_sender.clone();
            s.builder()
                .name(format!("generator-{}", i + 1))
                .spawn(move |_| loop {
                    // one dummy position at a time, as a flat 1-D tensor
                    let data = random_tensor(1344 * 1);
                    tensor_sender_clone.send(data).unwrap();
                })
                .unwrap();
        }
        // spawn the executors
        for i in 0..NUM_EXECUTORS {
            let tensor_receiver_clone = tensor_receiver.clone();
            s.builder()
                .name(format!("executor-{}", i + 1))
                .spawn(move |_| {
                    // create a separate Net instance within each executor thread
                    let net = Net::new("tz_6515.pt");
                    // un-timed warmup passes, so that one-off initialization
                    // costs do not skew the measured throughput
                    for _ in 0..NUM_WARMUPS {
                        let data = random_tensor(1344 * BATCH_SIZE); // 8*8*21 = 1344 per position
                        let _ = eval_state(data, &net).expect("Error");
                    }
                    let thread_name = std::thread::current()
                        .name()
                        .unwrap_or("unnamed-executor")
                        .to_owned();
                    // timed, benchmarked loop
                    let _full_run = Instant::now(); // currently unused
                    let mut input_vec: VecDeque<Tensor> = VecDeque::new();
                    let mut one_sec_timer = Instant::now();
                    let mut eval_counter = 0; // forward passes completed in the current second
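                    // main loop: pull single-position tensors off the channel until a
                    // full batch is collected, concatenate along dim 0, run one forward
                    // pass, and report the number of positions evaluated per second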
                    loop {
                        let data = tensor_receiver_clone.recv().unwrap();
                        input_vec.push_back(data);
                        if input_vec.len() == BATCH_SIZE {
                            let i_v = input_vec.make_contiguous();
                            let input_tensors = Tensor::cat(&i_v, 0);
                            let _ = eval_state(input_tensors, &net).expect("Error");
                            input_vec.clear();
                            eval_counter += 1;
                        }
                        // calculate and display evals/s roughly once per second
                        if one_sec_timer.elapsed() > Duration::from_secs(1) {
                            println!("{} {} evals/s", thread_name, BATCH_SIZE * eval_counter);
                            eval_counter = 0;
                            one_sec_timer = Instant::now();
                        }
                    }
                })
                .unwrap();
        }
    })
    .unwrap();
    // note: the generator and executor loops never return, so the scope above never
    // joins and the two lines below are unreachable unless the threads are made to exit
    let total_time_secs = entire_benchmark_timer.elapsed().as_nanos() as f32 / 1e9;
    println!("Benchmark ran for {}s", total_time_secs);
}
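// Builds a flat 1-D f32 tensor of `size` uniform random values in [0, 1).
// The values are placeholders only; eval_state (from tz_rust, not shown in this
// paste) is assumed to handle any reshaping and device transfer of the CPU tensor.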
fn random_tensor(size: usize) -> Tensor {
    let mut rng = rand::thread_rng();
    let data: Vec<f32> = (0..size).map(|_| rng.gen::<f32>()).collect();
    Tensor::from_slice(&data)
}