Advertisement
alkkofficial

Rust MCTS draft

Jul 21st, 2023 (edited)
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 4.28 KB | None | 0 0
  1. use cozy_chess::*;
  2. use core::panic;
  3. use std::env;
  4. use rand::*;
  5. use cozy_chess::GameStatus::Ongoing;
  6.  
  7. #[derive(Debug, Clone, PartialEq)]
  8. struct TreeNode {
  9.     M: f64, // eval
  10.     V: i32, // number of nodes visited
  11.     visitedMovesAndNodes: Vec<(cozy_chess::Move, Box<TreeNode>)>,
  12.     nonVisitedLegalMoves: Vec<cozy_chess::Move>,
  13.     parent: Option<Box<TreeNode>>,
  14.     board: Board,
  15. }
  16.  
  17. impl TreeNode {
  18.     fn isMctsLeafNode(&self) -> bool {
  19.         // println!("{}",!self.nonVisitedLegalMoves.is_empty());
  20.         return !self.nonVisitedLegalMoves.is_empty()
  21.     }
  22.  
  23.     fn isTerminalNode(&self) -> bool {
  24.         // println!("{}",self.nonVisitedLegalMoves.is_empty() && self.visitedMovesAndNodes.is_empty());
  25.         return self.nonVisitedLegalMoves.is_empty() && self.visitedMovesAndNodes.is_empty()
  26.     }
  27. }
  28.  
  29. fn uctvalue(node: &TreeNode, parent: &TreeNode) -> f64 {
  30.     let pv = parent.V as f64;
  31.     let nv = node.V as f64;
  32.     let nm = node.M as f64;
  33.     let val: f64 = (nm/nv) + 1.4142135624 * ((pv).log(2.718281828459) / nv).sqrt();
  34.  
  35.     return val;
  36. }
  37.  
  38. fn select(node: TreeNode) -> TreeNode {
  39.     if node.isMctsLeafNode() || node.isTerminalNode() {
  40.         return node;
  41.     } else {
  42.         let mut maxUctChild: Option<&Box<TreeNode>> = None; // Use Option type
  43.         let mut maxUctValue = -10000.0;
  44.         for (_, child) in node.visitedMovesAndNodes.iter() {
  45.             let uctValChild = uctvalue(child, &node);
  46.             if uctValChild > maxUctValue {
  47.                 maxUctChild = Some(child); // Assign Some(value)
  48.                 maxUctValue = uctValChild;
  49.             }
  50.         }
  51.         if let Some(maxUctChild) = maxUctChild {
  52.             return select((**maxUctChild).clone()); // Dereference twice to get TreeNode from Option<Box<TreeNode>>
  53.         } else {
  54.             panic!("NO BEST CHILD FOUND");
  55.         }
  56.     }
  57. }
  58. fn expand(mut node: TreeNode) -> TreeNode {
  59.     let moveToExpand = node.nonVisitedLegalMoves.remove(node.nonVisitedLegalMoves.len()-1);
  60.     let mut board = node.board.clone();
  61.     board.play(moveToExpand);
  62.     println!("{}", moveToExpand);
  63.     let childNode = Box::new(TreeNode {
  64.         M: 0.0,
  65.         V: 0,
  66.         visitedMovesAndNodes: Vec::new(),
  67.         nonVisitedLegalMoves: Vec::new(),
  68.         parent: Some(Box::new(node.clone())),
  69.         board,
  70.     });
  71.     node.visitedMovesAndNodes.push((moveToExpand, childNode.clone()));
  72.     return *childNode
  73. }
  74.  
  75. fn simulate(node: TreeNode) -> f64 {
  76.     let mut board = node.board.clone();
  77.     while board.status() == Ongoing {
  78.         let mut legal_moves: Vec<Move> = Vec::new();
  79.         board.generate_moves(|moves| {
  80.             legal_moves.extend(moves);
  81.             false
  82.         });
  83.         let move_idx = rand::thread_rng().gen_range(0..legal_moves.len());
  84.         board.play(legal_moves[move_idx]);
  85.     }
  86.     let payout = if board.status() == GameStatus::Drawn {
  87.         0.5
  88.     } else {
  89.        
  90.         let win = !board.side_to_move();
  91.         match win {
  92.             Color::White => 1.0,
  93.             Color::Black => 0.0,
  94.         }
  95.     };
  96.     return payout;
  97. }
  98.  
  99. fn backpropagate(mut node: TreeNode, payout: f64) {
  100.     node.M += payout;
  101.     node.V += 1;
  102.  
  103.     if let Some(parent) = node.parent {
  104.         return backpropagate(*parent, payout);
  105.     }
  106. }
  107.  
  108. fn main() {
  109.     env::set_var("RUST_BACKTRACE", "1");
  110.     let board = Board::default();
  111.     let root = TreeNode {
  112.         M: 0.0,
  113.         V: 0,
  114.         visitedMovesAndNodes: Vec::new(),
  115.         nonVisitedLegalMoves: Vec::new(),
  116.         parent: None,
  117.         board,
  118.     };
  119.  
  120.     for i in 0..5000 {
  121.         if i % 10 == 0 {
  122.             println!("{}", i)
  123.         }
  124.         if i % 100 == 0 {
  125.             for (m, child) in &root.visitedMovesAndNodes {
  126.                 println!("move:");
  127.                 println!("{}",m);
  128.                 println!("{}",child.M);
  129.                 println!("{}",child.V);
  130.             }
  131.         }
  132.         let mut node = select(root.clone());
  133.         if !node.isTerminalNode() {
  134.             node = expand(node);
  135.         }
  136.         let payout = simulate(node.clone());
  137.         backpropagate(node, payout);
  138.        
  139.     }
  140.  
  141.     for (m, child) in root.visitedMovesAndNodes{
  142.         println!("{}", m);
  143.         println!("{}",child.M);
  144.         println!("{}",child.V);
  145.     }
  146.  
  147. }
  148.  
  149.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement