Advertisement
alkkofficial

Rust MCTS eval_and_expand

Sep 2nd, 2023
161
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 3.31 KB | None | 0 0
  1. pub fn eval_board(
  2.     bs: &BoardStack,
  3.     net: &Net,
  4.     tree: &mut Tree,
  5.     selected_node_idx: &usize,
  6. ) -> Vec<usize> {
  7.     let contents = get_contents();
  8.     let b = convert_board(bs);
  9.  
  10.     let output = eval_state(b, &net).expect("Error");
  11.  
  12.     let (board_eval, policy) = output; // check policy, eval ordering!
  13.  
  14.     let board_eval = board_eval.squeeze();
  15.  
  16.     let board_eval: Vec<f32> = Vec::try_from(board_eval).expect("Error");
  17.  
  18.     let board_eval = Tensor::from_slice(&vec![board_eval[0]]);
  19.  
  20.     let value = Tensor::tanh(&board_eval);
  21.  
  22.     let policy = policy.squeeze();
  23.     let policy: Vec<f32> = Vec::try_from(policy).expect("Error");
  24.     let value = f32::try_from(value).expect("Error");
  25.  
  26.     let value = match bs.board().side_to_move() {
  27.             Color::Black => -value,
  28.             Color::White => value,
  29.     };
  30.  
  31.     // step 1 - get the corresponding idx for legal moves
  32.  
  33.     let mut legal_moves: Vec<Move> = Vec::new();
  34.     bs.board().generate_moves(|moves| {
  35.         // Unpack dense move set into move list
  36.         legal_moves.extend(moves);
  37.         false
  38.     });
  39.  
  40.     let mut fm: Vec<Move> = Vec::new();
  41.     if bs.board().side_to_move() == Color::Black {
  42.         // flip move
  43.         for mv in &legal_moves {
  44.             fm.push(Move {
  45.                 from: mv.from.flip_rank(),
  46.                 to: mv.to.flip_rank(),
  47.                 promotion: mv.promotion,
  48.             })
  49.         }
  50.     } else {
  51.         fm = legal_moves.clone();
  52.     }
  53.  
  54.     legal_moves = fm;
  55.  
  56.     let mut idx_li: Vec<usize> = Vec::new();
  57.  
  58.     for mov in &legal_moves {
  59.         // let mov = format!("{}", mov);
  60.         if let Some(idx) = contents.iter().position(|x| mov == x) {
  61.             idx_li.push(idx as usize);
  62.         }
  63.     }
  64.  
  65.     // step 2 - using the idx in step 1, index all the policies involved
  66.     let mut pol_list: Vec<f32> = Vec::new();
  67.     for id in &idx_li {
  68.         pol_list.push(policy[*id]);
  69.     }
  70.  
  71.     // println!("{:?}", pol_list);
  72.  
  73.     // step 3 - softmax
  74.  
  75.     let sm = Tensor::from_slice(&pol_list);
  76.  
  77.     let sm = Tensor::softmax(&sm, 0, Kind::Float);
  78.  
  79.     let pol_list: Vec<f32> = Vec::try_from(sm).expect("Error");
  80.  
  81.     // println!("{:?}", pol_list);
  82.  
  83.     // println!("        V={}", &value);
  84.  
  85.     // step 4 - iteratively append nodes into class
  86.     let mut counter = 0;
  87.     let ct = tree.nodes.len();
  88.     for (mv, pol) in legal_moves.iter().zip(pol_list.iter()) {
  89.         tree.nodes[*selected_node_idx].eval_score = value;
  90.         // tree.nodes[*selected_node_idx].eval_score = 0.0;
  91.         let fm: Move;
  92.         if bs.board().side_to_move() == Color::Black {
  93.             // flip move
  94.             fm = Move {
  95.                 from: mv.from.flip_rank(),
  96.                 to: mv.to.flip_rank(),
  97.                 promotion: mv.promotion,
  98.             };
  99.         } else {
  100.             fm = *mv;
  101.         }
  102.         let mut child = Node::new(0.0, Some(*selected_node_idx), Some(fm));
  103.         // let mut child = Node::new(*pol, Some(*selected_node_idx), Some(fm));
  104.         // println!("{:?}, {:?}, {:?}", mv, child.policy, child.eval_score);
  105.         tree.nodes.push(child); // push child to the tree Vec<Node>
  106.         tree.nodes[*selected_node_idx].children.push(counter + ct); // push numbers
  107.         counter += 1
  108.     }
  109.     // println!("{:?}", tree.nodes.len());
  110.     idx_li
  111. }
  112.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement