Advertisement
CaptainSpaceCat

General Monte Carlo Tree Search

Jun 26th, 2019
187
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.85 KB | None | 0 0
  1. import random
  2.  
  3. # ===================== Game-Specific Info ===================== #
  4. players = ["x", "o"]
  5. defaultState = {
  6.     "board" : [
  7.     ["-","-","-"],
  8.     ["-","-","-"],
  9.     ["-","-","-"]
  10.     ],
  11.     "turn" : "x",
  12. }
  13.  
  14. def getLegalActions(state):
  15.     actions = []
  16.     for i in range(3):
  17.         for v in range(3):
  18.             if state["board"][i][v] == "-":
  19.                 actions.append(str(i) + str(v) + state["turn"])
  20.     return actions
  21.  
  22. def getNextState(state, action):
  23.     newState = {"board" : []}
  24.     for i in range(3):
  25.         newState["board"].append([])
  26.         for v in range(3):
  27.             newState["board"][i].append(state["board"][i][v])
  28.     newState["board"][int(action[0])][int(action[1])] = action[2]
  29.     if state["turn"] == "x":
  30.         newState["turn"] = "o"
  31.     else:
  32.         newState["turn"] = "x"
  33.     return newState
  34.  
  35. def isPlayerWon(state, player):
  36.     for i in range(3):
  37.         if state["board"][i][0] == player and state["board"][i][1] == player and state["board"][i][2] == player:
  38.             return True
  39.         if state["board"][0][i] == player and state["board"][1][i] == player and state["board"][2][i] == player:
  40.             return True
  41.     if state["board"][0][0] == player and state["board"][1][1] == player and state["board"][2][2] == player:
  42.         return True
  43.     if state["board"][2][0] == player and state["board"][1][1] == player and state["board"][0][2] == player:
  44.         return True
  45.     return False
  46.  
  47. def isTerminal(state):
  48.     filled = True
  49.     for i in range(3):
  50.         for v in range(3):
  51.             if state["board"][i][v] == "-":
  52.                 filled = False
  53.                 break
  54.     if filled or not getVictor(state) == None:
  55.         return True
  56.     return False
  57.  
  58. def getVictor(state):
  59.     for p in players:
  60.         if isPlayerWon(state, p):
  61.             return p
  62.     return None
  63.  
  64. def getScore(state, player):
  65.     if not isTerminal(state):
  66.         return 0
  67.     victor = getVictor(state)
  68.     if victor == None:
  69.         return 0.5
  70.     elif victor == player:
  71.         return 1
  72.     return 0
  73.  
  74.  
  75. # ===================== General MCTS Functionality ===================== #
  76. def getNewNode(state):
  77.     newNode = {}
  78.     newNode["state"] = state
  79.     newNode["parent"] = None
  80.     newNode["children"] = {}
  81.     #start with a prior of 1 win per player
  82.     newNode["scores"] = {}
  83.     newNode["total"] = 0
  84.     for p in players:
  85.         newNode["scores"][p] = float(0)
  86.     newNode["playoutFlag"] = False
  87.     return newNode
  88.  
  89. def getRandomChild(stateNode):
  90.     #if a node is unexplored, populate its children to reduce runtime
  91.     if len(stateNode["children"]) == 0:
  92.         #if it's a terminal node, we can't explore so just return
  93.         if isTerminal(stateNode["state"]):
  94.             print("Error: attempt to get children of terminal state")
  95.             return None
  96.         #otherwise, get every legal child and create a new node for it
  97.         actions = getLegalActions(stateNode["state"])
  98.         for act in actions:
  99.             newNode = getNewNode(getNextState(stateNode["state"], act))
  100.             newNode["parent"] = stateNode
  101.             stateNode["children"][act] = newNode
  102.     #now, choose a random child
  103.     return stateNode["children"][random.choice(list(stateNode["children"].keys()))]
  104.  
  105. def simulate(stateNode, depth):
  106.     if isTerminal(stateNode["state"]):
  107.         return
  108.     while stateNode["playoutFlag"]:
  109.         stateNode = getRandomChild(stateNode)
  110.         if isTerminal(stateNode["state"]):
  111.             break
  112.     scores = playout(stateNode, depth)
  113.     stateNode["playoutFlag"] = True
  114.  
  115.     while True:
  116.         stateNode["total"] += 1
  117.         test = float(0)
  118.         for p in players:
  119.             stateNode["scores"][p] += scores[p]
  120.             test += scores[p]
  121.         if not test == 1:
  122.             print("error, test is " + str(test))
  123.         if stateNode["parent"] == None:
  124.             break
  125.         stateNode = stateNode["parent"]
  126.    
  127.  
  128. def playout(stateNode, depth):
  129.     for i in range(depth):
  130.         if isTerminal(stateNode["state"]):
  131.             break
  132.         stateNode = getRandomChild(stateNode)
  133.        
  134.     scores = {}
  135.     for p in players:
  136.         scores[p] = getScore(stateNode["state"], p)
  137.     return scores
  138.  
  139. def getBestAction(stateNode):
  140.     bestAction = None
  141.     bestRatio = None
  142.     player = stateNode["state"]["turn"]
  143.     for a, c in stateNode["children"].items():
  144.         ratio = float(c["scores"][player])/float(c["total"])
  145.         if bestRatio == None or ratio > bestRatio:
  146.             bestRatio = ratio
  147.             bestAction = a
  148.     return bestAction
  149.  
  150. def pruneHigherNodes(stateNode):
  151.     #prunes all nodes above this node and sets this node as the root
  152.     pass
  153.  
  154.  
  155. def printState(state):
  156.     for i in range(3):
  157.         line = ""
  158.         for v in range(3):
  159.             line += state["board"][i][v]
  160.         print(line)
  161.     print("> " + state["turn"] + "\n")
  162.  
  163. def printTree(node, depth = -1):
  164.     if depth == 0:
  165.         return
  166.     printState(node["state"])
  167.     print(str(node["scores"]) + "/" + str(node["total"]) + "\n")
  168.     for a, c in node["children"].items():
  169.         printTree(c, depth-1)
  170.  
# Driver: play one full game in which every move is chosen by the search.
# MAX_DEPTH is passed to simulate() as the cap on random moves per playout.
MAX_DEPTH = 100
# NUM_PLAYOUTS is the number of simulate() calls made before each move.
NUM_PLAYOUTS = 10000
rootNode = getNewNode(defaultState)
currentNode = rootNode
while not isTerminal(currentNode["state"]):
    # Gather statistics below the current position.
    for i in range(NUM_PLAYOUTS):
        simulate(currentNode, MAX_DEPTH)
    # Step into the child with the best ratio for the player to move; the
    # tree built so far is kept, so later searches reuse its statistics.
    bestAction = getBestAction(currentNode)
    currentNode = currentNode["children"][bestAction]
    printState(currentNode["state"])
    # NOTE(review): leftover debugging switch, presumably to stop after
    # the first move.
    #break

# Show the root and its direct children (two levels) with their statistics.
printTree(rootNode, 2)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement