// Advertisement
// Not a member of Pastebin yet?
// Sign Up,
// it unlocks many cool features!
use core::panic;
use std::env;

use cozy_chess::GameStatus::Ongoing;
use cozy_chess::*;
use rand::*;
- #[derive(Debug, Clone, PartialEq)]
- struct TreeNode {
- M: f64, // eval
- V: i32, // number of nodes visited
- visitedMovesAndNodes: Vec<(cozy_chess::Move, Box<TreeNode>)>,
- nonVisitedLegalMoves: Vec<cozy_chess::Move>,
- parent: Option<Box<TreeNode>>,
- board: Board,
- }
- impl TreeNode {
- fn isMctsLeafNode(&self) -> bool {
- // println!("{}",!self.nonVisitedLegalMoves.is_empty());
- return !self.nonVisitedLegalMoves.is_empty()
- }
- fn isTerminalNode(&self) -> bool {
- // println!("{}",self.nonVisitedLegalMoves.is_empty() && self.visitedMovesAndNodes.is_empty());
- return self.nonVisitedLegalMoves.is_empty() && self.visitedMovesAndNodes.is_empty()
- }
- }
- fn ucbscore(node: &TreeNode, parent: &TreeNode) -> f64 {
- let pv = parent.V as f64;
- let nv = node.V as f64;
- let nm = node.M as f64;
- let val: f64 = (nm/nv) + 1.4142135624 * ((pv).log(2.718281828459) / nv).sqrt();
- return val;
- }
- fn select(node: &mut TreeNode) -> TreeNode {
- if node.isMctsLeafNode() || node.isTerminalNode() {
- println!("NOPE");
- return *node;
- } else {
- println!("YEP");
- let mut maxUctChild = None; // Use Option type
- let mut maxUctValue = -10000.0;
- for (_, child) in &node.visitedMovesAndNodes {
- let uctValChild = ucbscore(child, &node);
- println!("{}", uctValChild);
- if uctValChild > maxUctValue {
- maxUctChild = Some(child);
- maxUctValue = uctValChild;
- }
- }
- if let Some(maxUctChild) = maxUctChild {
- println!("SOME");
- return select (&mut(**maxUctChild).clone()); // Dereference twice to get TreeNode from Option<Box<TreeNode>>
- } else {
- println!("NONE");
- panic!("NO BEST CHILD FOUND");
- }
- }
- }
- fn expand(node: &mut TreeNode) -> TreeNode {
- let moveToExpand = node.nonVisitedLegalMoves.remove(node.nonVisitedLegalMoves.len()-1);
- let mut board = node.board.clone();
- board.play(moveToExpand);
- println!("{:?}",moveToExpand);
- let mut legal_moves: Vec<Move> = Vec::new();
- board.generate_moves(|moves| {
- legal_moves.extend(moves);
- false
- });
- let childNode = Box::new(TreeNode {
- M: 0.0,
- V: 0,
- visitedMovesAndNodes: Vec::new(),
- nonVisitedLegalMoves: legal_moves,
- parent: Some(Box::new(node.clone())),
- board,
- });
- node.visitedMovesAndNodes.push((moveToExpand, childNode.clone()));
- return *childNode
- }
- fn simulate(node: &mut TreeNode) -> f64 {
- let mut board = node.board.clone();
- while board.status() == Ongoing {
- let mut legal_moves: Vec<Move> = Vec::new();
- board.generate_moves(|moves| {
- legal_moves.extend(moves);
- false
- });
- let move_idx = rand::thread_rng().gen_range(0..legal_moves.len());
- board.play(legal_moves[move_idx]);
- }
- let payout = if board.status() == GameStatus::Drawn {
- 0.5
- } else {
- let win = !board.side_to_move();
- match win {
- Color::White => 1.0,
- Color::Black => 0.0,
- }
- };
- return payout;
- }
- fn backpropagate(mut node: TreeNode, payout: f64) {
- node.M += payout;
- node.V += 1;
- if let Some(parent) = node.parent {
- return backpropagate(*parent, payout);
- }
- }
- fn main() {
- env::set_var("RUST_BACKTRACE", "1");
- let board = Board::default();
- let mut legal_moves: Vec<Move> = Vec::new();
- board.generate_moves(|moves| {
- legal_moves.extend(moves);
- false
- });
- let mut root = TreeNode {
- M: 0.0,
- V: 0,
- visitedMovesAndNodes: Vec::new(),
- nonVisitedLegalMoves: legal_moves,
- parent: None,
- board,
- };
- println!("{}",root.nonVisitedLegalMoves.len());
- for i in 0..1000000 {
- if i % 10 == 0 {
- println!("{}", i)
- }
- if i % 100 == 0 {
- for (m, child) in &root.visitedMovesAndNodes {
- println!("move:");
- println!("{}",m);
- println!("{}",child.M);
- println!("{}",child.V);
- }
- }
- let mut node = select(&mut root);
- if !node.isTerminalNode() {
- node = expand(&mut node)
- }
- let payout = simulate(&mut node);
- backpropagate(node, payout);
- }
- }
// Advertisement
// Add Comment
// Please, Sign In to add comment
// Advertisement