Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import random
# ===================== Game-Specific Info ===================== #

# Marks used by the two players, in the order they are iterated elsewhere.
players = ["x", "o"]

# Starting position: an empty 3x3 board ("-" marks a free cell), "x" to move.
defaultState = {
    "board": [["-", "-", "-"] for _ in range(3)],
    "turn": "x",
}
def getLegalActions(state):
    """Return every move available in ``state``.

    An action is a three-character string: row index, column index and the
    mark of the player to move (``state["turn"]``). Actions are listed in
    row-major order, one per cell that still holds "-".
    """
    board = state["board"]
    return [
        str(row) + str(col) + state["turn"]
        for row in range(3)
        for col in range(3)
        if board[row][col] == "-"
    ]
def getNextState(state, action):
    """Return the state reached by applying ``action`` to ``state``.

    ``state`` is left untouched: the 3x3 board is copied cell by cell, the
    mark in ``action[2]`` is written at row ``action[0]`` / column
    ``action[1]``, and the turn passes to the other player.
    """
    board = [[state["board"][row][col] for col in range(3)] for row in range(3)]

    # The mark is read before the target cell is resolved, row before column.
    mark = action[2]
    targetRow = board[int(action[0])]
    targetRow[int(action[1])] = mark

    nextTurn = "o" if state["turn"] == "x" else "x"
    return {"board": board, "turn": nextTurn}
# Cell coordinates of every winning line, in the order they are checked:
# row k then column k for k = 0..2, then the two diagonals.
_WIN_LINES = tuple(
    line
    for k in range(3)
    for line in (
        ((k, 0), (k, 1), (k, 2)),
        ((0, k), (1, k), (2, k)),
    )
) + (
    ((0, 0), (1, 1), (2, 2)),
    ((2, 0), (1, 1), (0, 2)),
)


def isPlayerWon(state, player):
    """Return True if ``player`` occupies a full row, column or diagonal."""
    board = state["board"]
    return any(
        all(board[row][col] == player for row, col in line)
        for line in _WIN_LINES
    )
def isTerminal(state):
    """Return True when the game is over.

    The game is over when no cell holds "-" or when ``getVictor`` reports a
    winner. The winner check only runs if the board still has a free cell.
    """
    board = state["board"]
    # Every row is inspected; within a row the scan stops at the first "-".
    rowHasGap = [
        any(board[row][col] == "-" for col in range(3))
        for row in range(3)
    ]
    if not any(rowHasGap):
        return True
    return getVictor(state) is not None
def getVictor(state):
    """Return the first entry of ``players`` that has won, or None."""
    winners = (mark for mark in players if isPlayerWon(state, mark))
    return next(winners, None)
def getScore(state, player):
    """Return the payoff of ``state`` from ``player``'s point of view.

    Returns 0 for a state that is not terminal, 0.5 for a finished game with
    no winner, 1 if ``player`` won and 0 if the other player won.
    """
    if not isTerminal(state):
        return 0

    winner = getVictor(state)
    if winner is None:
        return 0.5
    return 1 if winner == player else 0
- # ===================== General MCTS Functionality ===================== #
def getNewNode(state):
    """Create a fresh, unlinked tree node wrapping ``state``.

    The node has no parent and no children, a visit count (``total``) of 0,
    a score of 0.0 for each entry in ``players``, and ``playoutFlag`` False.
    """
    return {
        "state": state,
        "parent": None,
        "children": {},
        # Accumulated playout scores per player; nothing recorded yet.
        "scores": {mark: float(0) for mark in players},
        "total": 0,
        "playoutFlag": False,
    }
def getRandomChild(stateNode):
    """Return a uniformly random child of ``stateNode``.

    If the node has no children yet, one child node is created per legal
    action and linked back to ``stateNode`` before choosing. For a terminal
    node with no children an error line is printed and None is returned.
    """
    children = stateNode["children"]

    if not children:
        state = stateNode["state"]
        # A terminal node cannot be expanded.
        if isTerminal(state):
            print("Error: attempt to get children of terminal state")
            return None
        # Expand all legal moves at once so later visits reuse the nodes.
        for action in getLegalActions(state):
            child = getNewNode(getNextState(state, action))
            child["parent"] = stateNode
            children[action] = child

    return random.choice(list(children.values()))
def simulate(stateNode, depth):
    """Run one search iteration starting at ``stateNode``.

    Does nothing if ``stateNode`` is terminal. Otherwise it walks down
    through random children while the current node has ``playoutFlag`` set
    (stopping early on a terminal node), runs ``playout`` from the node it
    ends on, marks that node as played out, and adds the resulting scores
    and one visit to that node and to every ancestor up to the node whose
    parent is None.
    """
    node = stateNode
    if isTerminal(node["state"]):
        return

    # Descend through nodes that have already been played out.
    while node["playoutFlag"]:
        node = getRandomChild(node)
        if isTerminal(node["state"]):
            break

    scores = playout(node, depth)
    node["playoutFlag"] = True

    # Propagate the result from the selected node up to the top of the tree.
    while node is not None:
        node["total"] += 1
        scoreSum = float(0)
        for mark in players:
            node["scores"][mark] += scores[mark]
            scoreSum += scores[mark]
        # Sanity check: the per-player scores are expected to add up to 1.
        if scoreSum != 1:
            print("error, test is " + str(scoreSum))
        node = node["parent"]
def playout(stateNode, depth):
    """Play random moves from ``stateNode`` and score the position reached.

    At most ``depth`` random moves are made, stopping early at a terminal
    state. Returns a dict mapping each entry of ``players`` to its
    ``getScore`` value for the final state.
    """
    node = stateNode
    for _ in range(depth):
        if isTerminal(node["state"]):
            break
        node = getRandomChild(node)

    finalState = node["state"]
    return {mark: getScore(finalState, mark) for mark in players}
def getBestAction(stateNode):
    """Return the action whose child has the best score ratio for the mover.

    The ratio of a child is its accumulated score for the player to move at
    ``stateNode`` divided by the child's visit count (``total``). A child
    that has never been visited has no ratio; it is ranked below every
    visited child instead of causing a division by zero, but it can still be
    returned when no child has been visited. Ties keep the earliest child.

    Returns None when ``stateNode`` has no children.
    """
    player = stateNode["state"]["turn"]
    bestAction = None
    bestRatio = None

    for action, child in stateNode["children"].items():
        visits = child["total"]
        if visits:
            ratio = float(child["scores"][player]) / float(visits)
        else:
            # Nodes start with total == 0, so an unvisited child is possible.
            ratio = float("-inf")

        if bestRatio is None or ratio > bestRatio:
            bestRatio = ratio
            bestAction = action

    return bestAction
def pruneHigherNodes(stateNode):
    """Make ``stateNode`` the root of its tree by dropping its parent link.

    After the call ``stateNode["parent"]`` is None, so anything that walks
    parent links upward from this subtree stops here. The former ancestors
    are not modified; they simply stop being reachable from this node.
    Calling it on a node that is already a root changes nothing.
    """
    stateNode["parent"] = None
def printState(state):
    """Print the 3x3 board one row per line, then "> " and whose turn it is.

    The turn line is followed by an extra blank line.
    """
    board = state["board"]
    for row in range(3):
        print("".join(board[row][col] for col in range(3)))
    print("> " + state["turn"] + "\n")
def printTree(node, depth = -1):
    """Print ``node`` and its descendants, depth-first.

    For each node the state is printed, followed by its scores and visit
    count as "scores/total". ``depth`` is the number of levels to print; the
    recursion stops when it reaches 0, so the default of -1 never reaches 0
    by counting down and the whole subtree is printed.
    """
    if depth != 0:
        printState(node["state"])
        print("{}/{}\n".format(str(node["scores"]), str(node["total"])))
        for child in node["children"].values():
            printTree(child, depth - 1)
# Search budget: maximum random moves per playout, and playouts per move.
MAX_DEPTH = 100
NUM_PLAYOUTS = 10000

# The search tree starts at the empty board; currentNode tracks the game.
rootNode = getNewNode(defaultState)
currentNode = rootNode

# Self-play: search from the current position, play the best-rated move,
# show the resulting board, and repeat until the game is over.
while not isTerminal(currentNode["state"]):
    for i in range(NUM_PLAYOUTS):
        simulate(currentNode, MAX_DEPTH)
    bestAction = getBestAction(currentNode)
    currentNode = currentNode["children"][bestAction]
    printState(currentNode["state"])

# Dump the top two levels of the tree with their statistics.
printTree(rootNode, 2)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement