Advertisement
alkkofficial

Untitled

Jul 22nd, 2023 (edited)
126
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.62 KB | None | 0 0
  1. use cozy_chess::*;
  2. use core::panic;
  3. use std::env;
  4. use rand::*;
  5. use cozy_chess::GameStatus::Ongoing;
  6.  
  7. #[derive(Debug, Clone, PartialEq)]
  8. struct TreeNode {
  9. M: f64, // eval
  10. V: i32, // number of nodes visited
  11. visitedMovesAndNodes: Vec<(cozy_chess::Move, Box<TreeNode>)>,
  12. nonVisitedLegalMoves: Vec<cozy_chess::Move>,
  13. parent: Option<Box<TreeNode>>,
  14. board: Board,
  15. }
  16.  
  17. impl TreeNode {
  18. fn isMctsLeafNode(&self) -> bool {
  19. // println!("{}",!self.nonVisitedLegalMoves.is_empty());
  20. return !self.nonVisitedLegalMoves.is_empty()
  21. }
  22.  
  23. fn isTerminalNode(&self) -> bool {
  24. // println!("{}",self.nonVisitedLegalMoves.is_empty() && self.visitedMovesAndNodes.is_empty());
  25. return self.nonVisitedLegalMoves.is_empty() && self.visitedMovesAndNodes.is_empty()
  26. }
  27. }
  28.  
  29.  
  30.  
  31. fn ucbscore(node: &TreeNode, parent: &TreeNode) -> f64 {
  32. let pv = parent.V as f64;
  33. let nv = node.V as f64;
  34. let nm = node.M as f64;
  35. let val: f64 = (nm/nv) + 1.4142135624 * ((pv).log(2.718281828459) / nv).sqrt();
  36. return val;
  37. }
  38.  
  39. fn select(node: &mut TreeNode) -> TreeNode {
  40. if node.isMctsLeafNode() || node.isTerminalNode() {
  41. println!("NOPE");
  42. return *node;
  43. } else {
  44. println!("YEP");
  45. let mut maxUctChild = None; // Use Option type
  46. let mut maxUctValue = -10000.0;
  47. for (_, child) in &node.visitedMovesAndNodes {
  48. let uctValChild = ucbscore(child, &node);
  49. println!("{}", uctValChild);
  50. if uctValChild > maxUctValue {
  51. maxUctChild = Some(child);
  52. maxUctValue = uctValChild;
  53. }
  54. }
  55. if let Some(maxUctChild) = maxUctChild {
  56. println!("SOME");
  57. return select (&mut(**maxUctChild).clone()); // Dereference twice to get TreeNode from Option<Box<TreeNode>>
  58. } else {
  59. println!("NONE");
  60. panic!("NO BEST CHILD FOUND");
  61. }
  62. }
  63. }
  64. fn expand(node: &mut TreeNode) -> TreeNode {
  65. let moveToExpand = node.nonVisitedLegalMoves.remove(node.nonVisitedLegalMoves.len()-1);
  66. let mut board = node.board.clone();
  67. board.play(moveToExpand);
  68. println!("{:?}",moveToExpand);
  69. let mut legal_moves: Vec<Move> = Vec::new();
  70. board.generate_moves(|moves| {
  71. legal_moves.extend(moves);
  72. false
  73. });
  74. let childNode = Box::new(TreeNode {
  75. M: 0.0,
  76. V: 0,
  77. visitedMovesAndNodes: Vec::new(),
  78. nonVisitedLegalMoves: legal_moves,
  79. parent: Some(Box::new(node.clone())),
  80. board,
  81. });
  82. node.visitedMovesAndNodes.push((moveToExpand, childNode.clone()));
  83. return *childNode
  84. }
  85.  
  86. fn simulate(node: &mut TreeNode) -> f64 {
  87. let mut board = node.board.clone();
  88. while board.status() == Ongoing {
  89. let mut legal_moves: Vec<Move> = Vec::new();
  90. board.generate_moves(|moves| {
  91. legal_moves.extend(moves);
  92. false
  93. });
  94. let move_idx = rand::thread_rng().gen_range(0..legal_moves.len());
  95. board.play(legal_moves[move_idx]);
  96. }
  97. let payout = if board.status() == GameStatus::Drawn {
  98. 0.5
  99. } else {
  100.  
  101. let win = !board.side_to_move();
  102. match win {
  103. Color::White => 1.0,
  104. Color::Black => 0.0,
  105. }
  106. };
  107. return payout;
  108. }
  109.  
  110. fn backpropagate(mut node: TreeNode, payout: f64) {
  111. node.M += payout;
  112. node.V += 1;
  113. if let Some(parent) = node.parent {
  114. return backpropagate(*parent, payout);
  115. }
  116.  
  117. }
  118.  
  119. fn main() {
  120. env::set_var("RUST_BACKTRACE", "1");
  121. let board = Board::default();
  122. let mut legal_moves: Vec<Move> = Vec::new();
  123. board.generate_moves(|moves| {
  124. legal_moves.extend(moves);
  125. false
  126. });
  127. let mut root = TreeNode {
  128. M: 0.0,
  129. V: 0,
  130. visitedMovesAndNodes: Vec::new(),
  131. nonVisitedLegalMoves: legal_moves,
  132. parent: None,
  133. board,
  134. };
  135. println!("{}",root.nonVisitedLegalMoves.len());
  136. for i in 0..1000000 {
  137. if i % 10 == 0 {
  138. println!("{}", i)
  139. }
  140. if i % 100 == 0 {
  141. for (m, child) in &root.visitedMovesAndNodes {
  142. println!("move:");
  143. println!("{}",m);
  144. println!("{}",child.M);
  145. println!("{}",child.V);
  146. }
  147. }
  148. let mut node = select(&mut root);
  149. if !node.isTerminalNode() {
  150. node = expand(&mut node)
  151. }
  152. let payout = simulate(&mut node);
  153. backpropagate(node, payout);
  154.  
  155. }
  156.  
  157.  
  158. }
  159.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement