Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- import seaborn as sns
- import matplotlib.pyplot as plt
def create_maze():
    """Build the 39x39 maze plus a copy with two reference routes drawn on it.

    Cell codes: 0 - empty, -1 - start, -2 - exit, -4 - wall. In the annotated
    copy every cell of the two hard-coded reference routes is set to -0.5.

    Returns:
        tuple: ``(maze, maze_idx, maze2, start_pos, exit)`` -- the board, every
        ``(row, col)`` index of the board in row-major order, the board with
        the reference routes drawn on it, the start cell, and the list of
        exit cells.
    """
    size = 39
    wall = -4
    board = np.zeros((size, size), dtype=float)
    # Every grid index in row-major (C) order.
    all_cells = list(np.ndindex(size, size))

    # Outer frame; the start and exit cells set below overwrite it.
    board[[0, -1], :] = wall
    board[:, [0, -1]] = wall

    inner_walls = [
        (4, 1), (8, 1), (12, 1), (18, 1), (20, 1), (28, 1), (34, 1),
        (2, 2), (4, 2), (6, 2), (8, 2), (10, 2), (12, 2), (14, 2), (16, 2), (17, 2), (18, 2), (20, 2), (22, 2),
        (23, 2), (24, 2), (25, 2), (26, 2), (28, 2), (30, 2), (31, 2), (32, 2), (34, 2), (36, 2),
        (2, 3), (6, 3), (8, 3), (10, 3), (14, 3), (20, 3), (26, 3), (28, 3), (32, 3), (36, 3),
        (2, 4), (3, 4), (4, 4), (5, 4), (6, 4), (7, 4), (8, 4), (10, 4), (11, 4), (12, 4), (13, 4), (14, 4), (16, 4),
        (18, 4), (19, 4), (20, 4), (22, 4), (23, 4), (24, 4), (26, 4), (27, 4), (28, 4), (29, 4), (30, 4),
        (32, 4), (33, 4), (34, 4), (35, 4), (36, 4),
        (4, 5), (8, 5), (10, 5), (12, 5), (16, 5), (22, 5), (24, 5), (26, 5), (32, 5), (36, 5),
        (1, 6), (2, 6), (4, 6), (6, 6), (8, 6), (10, 6), (12, 6), (14, 6), (15, 6), (16, 6), (17, 6), (18, 6), (19, 6),
        (20, 6), (21, 6), (22, 6), (24, 6), (26, 6), (28, 6), (29, 6), (30, 6), (32, 6), (33, 6), (34, 6), (36, 6),
        (4, 7), (6, 7), (10, 7), (12, 7), (18, 7), (24, 7), (28, 7), (30, 7), (34, 7), (36, 7),
        (2, 8), (3, 8), (4, 8), (6, 8), (8, 8), (9, 8), (10, 8), (12, 8), (14, 8), (15, 8), (16, 8), (18, 8), (20, 8),
        (21, 8), (22, 8), (23, 8), (24, 8), (25, 8), (26, 8), (27, 8), (28, 8), (30, 8), (32, 8), (34, 8), (36, 8),
        (2, 9), (6, 9), (8, 9), (12, 9), (16, 9), (24, 9), (28, 9), (32, 9), (36, 9),
        (2, 10), (4, 10), (5, 10), (6, 10), (8, 10), (10, 10), (11, 10), (12, 10), (13, 10), (14, 10), (16, 10), (17, 10),
        (18, 10), (20, 10), (21, 10), (22, 10), (24, 10), (26, 10), (28, 10), (30, 10), (31, 10), (32, 10), (33, 10),
        (34, 10), (36, 10),
        (2, 11), (4, 11), (8, 11), (12, 11), (16, 11), (22, 11), (24, 11), (26, 11), (28, 11), (30, 11), (34, 11), (36, 11),
        (2, 12), (3, 12), (4, 12), (6, 12), (7, 12), (8, 12), (9, 12), (10, 12), (12, 12), (14, 12), (15, 12), (16, 12),
        (17, 12), (18, 12), (20, 12), (22, 12), (23, 12), (24, 12), (26, 12), (28, 12), (30, 12), (32, 12), (34, 12),
        (36, 12), (37, 12),
        (6, 13), (12, 13), (18, 13), (20, 13), (24, 13), (26, 13), (30, 13), (32, 13),
        (2, 14), (4, 14), (6, 14), (8, 14), (9, 14), (10, 14), (12, 14), (13, 14), (14, 14), (15, 14), (16, 14), (18, 14),
        (20, 14), (21, 14), (22, 14), (24, 14), (25, 14), (26, 14), (27, 14), (28, 14), (29, 14), (30, 14), (32, 14),
        (33, 14), (34, 14), (35, 14), (36, 14),
        (2, 15), (4, 15), (6, 15), (8, 15), (10, 15), (16, 15), (18, 15), (20, 15), (22, 15), (24, 15), (30, 15), (32, 15),
        (36, 15),
        (1, 16), (2, 16), (4, 16), (5, 16), (6, 16), (8, 16), (10, 16), (11, 16), (12, 16), (13, 16), (14, 16), (16, 16),
        (17, 16), (18, 16), (20, 16), (22, 16), (24, 16), (26, 16), (27, 16), (28, 16), (30, 16), (31, 16), (32, 16),
        (34, 16), (35, 16), (36, 16),
        (4, 17), (14, 17), (18, 17), (22, 17), (26, 17), (32, 17),
        (2, 18), (3, 18), (4, 18), (6, 18), (7, 18), (8, 18), (9, 18), (10, 18), (11, 18), (12, 18), (14, 18), (16, 18),
        (18, 18), (19, 18), (20, 18), (21, 18), (22, 18), (23, 18), (24, 18), (25, 18), (26, 18), (27, 18), (28, 18),
        (29, 18), (30, 18), (32, 18), (33, 18), (34, 18), (36, 18),
        (2, 19), (6, 19), (8, 19), (10, 19), (14, 19), (16, 19), (22, 19), (30, 19), (34, 19), (36, 19),
        (2, 20), (4, 20), (5, 20), (6, 20), (8, 20), (10, 20), (11, 20), (12, 20), (13, 20), (14, 20), (16, 20), (17, 20),
        (18, 20), (19, 20), (20, 20), (22, 20), (24, 20), (25, 20), (26, 20), (27, 20), (28, 20), (30, 20), (32, 20),
        (34, 20), (36, 20), (37, 20),
        (2, 21), (4, 21), (8, 21), (12, 21), (16, 21), (20, 21), (22, 21), (28, 21), (32, 21), (34, 21),
        (2, 22), (4, 22), (5, 22), (6, 22), (8, 22), (9, 22), (10, 22), (12, 22), (13, 22), (14, 22), (16, 22), (18, 22),
        (20, 22), (22, 22), (23, 22), (24, 22), (25, 22), (26, 22), (28, 22), (29, 22), (30, 22), (31, 22), (32, 22),
        (34, 22), (35, 22), (36, 22),
        (2, 23), (6, 23), (10, 23), (14, 23), (18, 23), (20, 23), (24, 23), (30, 23), (34, 23),
        (2, 24), (3, 24), (4, 24), (6, 24), (7, 24), (8, 24), (10, 24), (11, 24), (12, 24), (14, 24), (15, 24), (16, 24),
        (17, 24), (18, 24), (20, 24), (21, 24), (22, 24), (24, 24), (25, 24), (26, 24), (27, 24), (28, 24), (29, 24),
        (30, 24), (32, 24), (33, 24), (34, 24), (36, 24), (37, 24),
        (4, 25), (12, 25), (18, 25), (20, 25), (22, 25), (28, 25), (32, 25), (34, 25),
        (1, 26), (2, 26), (4, 26), (5, 26), (6, 26), (8, 26), (9, 26), (10, 26), (11, 26), (12, 26), (13, 26), (14, 26),
        (15, 26), (16, 26), (18, 26), (20, 26), (22, 26), (23, 26), (24, 26), (26, 26), (27, 26), (28, 26), (30, 26),
        (31, 26), (32, 26), (34, 26), (35, 26), (36, 26),
        (4, 27), (6, 27), (8, 27), (12, 27), (14, 27), (18, 27), (24, 27), (28, 27), (30, 27), (34, 27),
        (2, 28), (3, 28), (4, 28), (6, 28), (8, 28), (10, 28), (12, 28), (14, 28), (16, 28), (17, 28), (18, 28), (19, 28),
        (20, 28), (21, 28), (22, 28), (23, 28), (24, 28), (25, 28), (26, 28), (28, 28), (30, 28), (32, 28), (33, 28),
        (34, 28), (36, 28), (37, 28),
        (4, 29), (6, 29), (10, 29), (14, 29), (18, 29), (20, 29), (30, 29), (36, 29),
        (2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (11, 30), (12, 30), (13, 30), (14, 30), (15, 30), (16, 30), (18, 30),
        (20, 30), (22, 30), (23, 30), (24, 30), (25, 30), (26, 30), (27, 30), (28, 30), (29, 30), (30, 30), (31, 30),
        (32, 30), (34, 30), (35, 30), (36, 30),
        (2, 31), (4, 31), (6, 31), (8, 31), (12, 31), (16, 31), (18, 31), (22, 31), (26, 31), (30, 31), (34, 31),
        (1, 32), (2, 32), (4, 32), (6, 32), (7, 32), (8, 32), (9, 32), (10, 32), (12, 32), (14, 32), (15, 32), (16, 32),
        (18, 32), (19, 32), (20, 32), (21, 32), (22, 32), (24, 32), (26, 32), (28, 32), (29, 32), (30, 32), (32, 32),
        (33, 32), (34, 32), (36, 32), (37, 32),
        (4, 33), (10, 33), (12, 33), (18, 33), (24, 33), (28, 33), (32, 33),
        (2, 34), (3, 34), (4, 34), (5, 34), (6, 34), (7, 34), (8, 34), (10, 34), (12, 34), (14, 34), (15, 34), (16, 34),
        (18, 34), (20, 34), (21, 34), (22, 34), (23, 34), (24, 34), (25, 34), (26, 34), (28, 34), (30, 34), (31, 34),
        (32, 34), (33, 34), (34, 34), (35, 34), (36, 34),
        (4, 35), (8, 35), (10, 35), (16, 35), (22, 35), (26, 35), (28, 35), (34, 35),
        (2, 36), (4, 36), (6, 36), (8, 36), (10, 36), (11, 36), (12, 36), (14, 36), (16, 36), (17, 36), (18, 36),
        (20, 36), (22, 36), (24, 36), (26, 36), (27, 36), (28, 36), (29, 36), (30, 36), (32, 36), (34, 36), (36, 36),
        (2, 37), (6, 37), (8, 37), (14, 37), (18, 37), (20, 37), (24, 37), (30, 37), (32, 37), (36, 37),
    ]
    # One fancy-index assignment: (rows_array, cols_array).
    board[tuple(np.array(inner_walls).T)] = wall

    # Exits (cut into the right-hand frame column).
    exits = [(19, 38), (31, 38)]
    for cell in exits:
        board[cell] = -2
    # Start (cut into the left-hand frame column).
    start = (19, 0)
    board[start] = -1

    # Reference route ending next to the exit at (31, 38).
    route_to_exit_31 = [
        (19, 1), (19, 2), (19, 3), (18, 3), (17, 3), (17, 4), (17, 5), (18, 5), (19, 5), (20, 5), (21, 5), (21, 4),
        (21, 3), (22, 3), (23, 3), (24, 3), (25, 3), (25, 4), (25, 5), (25, 6), (25, 7), (26, 7), (27, 7), (27, 6),
        (27, 5), (28, 5), (29, 5), (30, 5), (31, 5), (31, 6), (31, 7), (32, 7), (33, 7), (33, 8), (33, 9), (34, 9),
        (35, 9), (35, 10), (35, 11), (35, 12), (35, 13), (36, 13), (37, 13), (37, 14), (37, 15), (37, 16), (37, 17),
        (36, 17), (35, 17), (35, 18), (35, 19), (35, 20), (35, 21), (36, 21), (37, 21), (37, 22), (37, 23), (36, 23),
        (35, 23), (35, 24), (35, 25), (36, 25), (37, 25), (37, 26), (37, 27), (36, 27), (35, 27), (35, 28), (35, 29),
        (34, 29), (33, 29), (33, 30), (33, 31), (32, 31), (31, 31), (31, 32), (31, 33), (30, 33), (29, 33), (29, 34),
        (29, 35), (30, 35), (31, 35), (31, 36), (31, 37),
    ]
    # Reference route ending next to the exit at (19, 38).
    route_to_exit_19 = [
        (19, 1), (19, 2), (19, 3), (18, 3), (17, 3), (16, 3), (15, 3), (15, 4), (15, 5), (14, 5), (13, 5), (13, 6),
        (13, 7), (14, 7), (15, 7), (16, 7), (17, 7), (17, 8), (17, 9), (18, 9), (19, 9), (19, 10), (19, 11), (20, 11),
        (21, 11), (21, 12), (21, 13), (22, 13), (23, 13), (23, 14), (23, 15), (23, 16), (23, 17), (24, 17), (25, 17),
        (25, 16), (25, 15), (26, 15), (27, 15), (28, 15), (29, 15), (29, 16), (29, 17), (30, 17), (31, 17), (31, 18),
        (31, 19), (32, 19), (33, 19), (33, 20), (33, 21), (33, 22), (33, 23), (32, 23), (31, 23), (31, 24), (31, 25),
        (30, 25), (29, 25), (29, 26), (29, 27), (29, 28), (29, 29), (28, 29), (27, 29), (27, 28), (27, 27), (26, 27),
        (25, 27), (25, 26), (25, 25), (24, 25), (23, 25), (23, 24), (23, 23), (22, 23), (21, 23), (21, 22), (21, 21),
        (21, 20), (21, 19), (20, 19), (19, 19), (18, 19), (17, 19), (17, 18), (17, 17), (16, 17), (15, 17), (15, 16),
        (15, 15), (14, 15), (13, 15), (12, 15), (11, 15), (11, 14), (11, 13), (10, 13), (9, 13), (8, 13), (7, 13),
        (7, 14), (7, 15), (7, 16), (7, 17), (6, 17), (5, 17), (5, 18), (5, 19), (4, 19), (3, 19), (3, 20), (3, 21),
        (3, 22), (3, 23), (4, 23), (5, 23), (5, 24), (5, 25), (6, 25), (7, 25), (7, 26), (7, 27), (7, 28), (7, 29),
        (8, 29), (9, 29), (9, 30), (9, 31), (10, 31), (11, 31), (11, 32), (11, 33), (11, 34), (11, 35), (12, 35),
        (13, 35), (13, 34), (13, 33), (14, 33), (15, 33), (16, 33), (17, 33), (17, 34), (17, 35), (18, 35), (19, 35),
        (19, 36), (19, 37),
    ]

    annotated = board.copy()
    annotated[tuple(np.array(route_to_exit_19 + route_to_exit_31).T)] = -0.5
    return board, all_cells, annotated, start, exits
class Maze:
    """Grid maze environment driven by a tabular Q-learning agent.

    Cell codes in ``maze``: 0 - empty, -1 - start, -2 - exit, -4 - wall.
    Walls do not block movement: any in-grid neighbour is a legal move, and
    stepping onto a wall is only penalised through :meth:`get_reward`.
    """

    def __init__(self, maze, maze_idx, maze_with_path, current_position, exit):
        """Store the board and the episode state.

        Args:
            maze: 2-D array of cell codes (see the class docstring).
            maze_idx: every ``(row, col)`` index the agent may move to.
            maze_with_path: board annotated with a reference route; only
                used for display.
            current_position: start cell; also remembered as ``start`` and
                restored at the beginning of every episode.
            exit: list of exit cells.
        """
        self.maze = maze
        self.maze_idx = maze_idx
        # Hash-set copy for O(1) membership tests in available_moves (the list
        # form is kept for callers that read it). Rebuild it if maze_idx is
        # replaced after construction.
        self._valid_cells = set(maze_idx)
        self.maze_with_path = maze_with_path
        self.current_position = current_position
        self.start = current_position
        self.exit = exit

    def display_maze(self, maze, path=None, arrow_path=None, title='Лабіринт'):
        """Draw ``maze`` as a heatmap, optionally annotated with move arrows.

        Args:
            maze: 2-D array to render.
            path: cells the moves started from (one per arrow).
            arrow_path: arrow glyphs, paired element-wise with ``path``.
            title: figure title.
        """
        if path is None:
            path = []
        if arrow_path is None:
            arrow_path = []
        plt.figure(figsize=(10, 10))
        plt.title(title)
        # x is the column and y the row; the offsets place the glyph inside
        # the heatmap cell.
        for move, arrow in zip(path, arrow_path):
            plt.text(move[1] + 0.15, move[0] + 0.75, arrow, color='green', fontsize=10)
        sns.heatmap(maze, cmap='hot', cbar=False, square=True, linewidth=0.5,
                    linecolor='black', annot=False, fmt='')
        plt.show()

    def available_moves(self, pos):
        """Return the up/down/left/right neighbours of ``pos`` that are in ``maze_idx``."""
        row, col = pos
        candidates = [(row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)]
        return [move for move in candidates if move in self._valid_cells]

    def make_move(self, move):
        """Move to ``move`` and return the arrow for the step.

        Returns one of "↑", "→", "↓", "←", or None when ``move`` is not
        adjacent to the current position.
        """
        direction = None
        if move[0] - self.current_position[0] == -1:
            direction = "↑"
        elif move[1] - self.current_position[1] == 1:
            direction = "→"
        elif move[0] - self.current_position[0] == 1:
            direction = "↓"
        elif move[1] - self.current_position[1] == -1:
            direction = "←"
        self.current_position = move
        return direction

    def check_win(self):
        """Return 1 if the current position is an exit cell, else 0."""
        return 1 if self.current_position in self.exit else 0

    def get_state(self):
        """Return the current ``(row, col)`` position."""
        return self.current_position

    def get_reward(self, move):
        """Reward for stepping onto ``move``.

        -1 for empty or start cells, -1000 for walls, 0 for exits; None for
        any other cell code.
        """
        if self.maze[move] == 0:
            return -1
        elif self.maze[move] == -4:
            return -1000
        elif self.maze[move] == -2:
            return 0
        elif self.maze[move] == -1:
            return -1

    def train(self, agent, epochs=10, use_low_eps=False):
        """Run ``epochs`` training episodes, each from start until an exit is reached.

        Episodes have no step cap. The agent's path is rendered every ~10% of
        training.

        Args:
            agent: object with ``choose_move(game)``,
                ``update_Q_values(game, state, action, reward, next_state)``
                and writable ``rows``, ``cols``, ``epsilon`` attributes.
            epochs: number of episodes.
            use_low_eps: if true, multiply ``agent.epsilon`` by 0.825 at each
                10% milestone of training.

        Returns:
            ``(np.arange(1, epochs + 1), len_moves)``, where ``len_moves[i]`` is
            the number of moves taken in episode ``i + 1``.
        """
        agent.rows, agent.cols = len(self.maze), len(self.maze[0])
        decay_epochs = [round(0.1*i*epochs) for i in range(1, 11)]
        # At least 1, so small `epochs` (round(epochs*0.1) == 0 for
        # epochs <= 5) does not cause a modulo by zero.
        display_every = max(1, round(epochs*0.1))
        len_moves = []
        for epoch in range(1, epochs + 1):
            if use_low_eps and epoch in decay_epochs:
                agent.epsilon -= agent.epsilon*0.175
            path, arrow_path = [], []
            self.current_position = self.start  # reset the board
            while True:
                state = self.current_position
                path.append(state)
                action = agent.choose_move(self)
                arrow_path.append(self.make_move(action))
                next_state = self.current_position
                reward = self.get_reward(action)
                agent.update_Q_values(self, state, action, reward, next_state)
                if self.check_win() == 1:
                    break
            len_moves.append(len(path))
            if epoch % display_every == 0:
                self.display_maze(self.maze, path=path, arrow_path=arrow_path,
                                  title=fr'epoch {epoch} - {next_state}')
        return np.arange(1, epochs + 1), len_moves

    def play(self, agent, len_sh_path=86, max_moves=1000):
        """Play one greedy episode (``agent.epsilon`` is set to 0) and render it.

        Args:
            agent: trained agent (see :meth:`train`).
            len_sh_path: reference minimum move count shown in the title.
            max_moves: give up after this many moves.
        """
        agent.rows, agent.cols = len(self.maze), len(self.maze[0])
        agent.epsilon = 0.0
        path, arrow_path = [], []
        self.current_position = self.start  # reset the board
        while True:
            path.append(self.current_position)
            action = agent.choose_move(self)
            arrow_path.append(self.make_move(action))
            # Check for the win first, so that reaching the exit on the last
            # allowed move counts as success.
            if self.check_win() == 1:
                self.display_maze(self.maze, path=path, arrow_path=arrow_path,
                                  title=fr'Кількість ходів - {len(path)}(min={len_sh_path})')
                break
            if len(path) >= max_moves:
                self.display_maze(self.maze, path=path, arrow_path=arrow_path,
                                  title=fr'Вихід не знайдено(Кількість ходів - {len(path)})')
                break
class QLearningAgent:
    """Tabular epsilon-greedy Q-learning agent for a grid maze.

    ``Q_values`` maps a state ``(row, col)`` to an array shaped like
    ``game.maze``. Entry ``[r, c]`` is the Q-value of moving from that state
    to cell ``(r, c)``. Cells that are not a legal move from the state stay at
    ``-inf``, so the greedy ``max`` never selects them.
    """

    def __init__(self, learning_rate=10**-4, discount_factor=0.9, epsilon=0.5):
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon  # probability of taking a random move
        self.Q_values = {}
        # Board size; assigned by Maze.train/Maze.play and kept for
        # compatibility. Q rows are sized from game.maze instead, so a missing
        # or transposed value cannot corrupt the table.
        self.rows = None
        self.cols = None

    def choose_move(self, game):
        """Pick the next move epsilon-greedily.

        A random legal move is taken with probability ``epsilon``, or when the
        state has no Q row yet. Otherwise one of the highest-valued moves is
        chosen at random, returned as a tuple of plain ``int``s.
        """
        state = game.get_state()
        moves = game.available_moves(state)
        if np.random.uniform(0, 1) < self.epsilon or state not in self.Q_values:
            return moves[np.random.randint(0, len(moves))]
        q_values = self.Q_values[state]
        best_rows, best_cols = np.where(q_values == np.max(q_values))
        # Plain ints keep positions and Q-table keys free of numpy scalars.
        best_moves = [(int(r), int(c)) for r, c in zip(best_rows, best_cols)]
        return best_moves[np.random.randint(0, len(best_moves))]

    def _ensure_state(self, game, state):
        """Create the Q row for ``state`` (0 for legal moves, -inf elsewhere) if missing."""
        if state not in self.Q_values:
            q_row = np.full(np.shape(game.maze), -np.inf)
            for move in game.available_moves(state):
                q_row[move] = 0.0
            self.Q_values[state] = q_row

    def update_Q_values(self, game, state, action, reward, next_state):
        """Apply one Q-learning update for the transition ``state --action--> next_state``.

        Q(s, a) += lr * (reward + gamma * max Q(s', .) - Q(s, a))
        """
        self._ensure_state(game, state)
        self._ensure_state(game, next_state)
        old_q_value = self.Q_values[state][action]
        max_next_q_value = np.max(self.Q_values[next_state])
        self.Q_values[state][action] = old_q_value + self.learning_rate * (
            reward + self.discount_factor * max_next_q_value - old_q_value)
def create_small_maze():
    """Build the 9x9 maze used for the quick sanity run.

    Cell codes: 0 - empty, -1 - start, -2 - exit, -4 - wall.

    Returns:
        tuple: ``(maze, maze_idx, maze2, start_pos, exit)`` -- the board, every
        ``(row, col)`` index in row-major order, a plain copy of the board
        (no reference route is drawn for this maze), the start cell, and the
        list of exit cells.
    """
    side = 9
    wall = -4
    board = np.zeros((side, side), dtype=float)
    cells = list(np.ndindex(side, side))

    # Outer frame; the start and exit cells set below overwrite it.
    board[[0, -1], :] = wall
    board[:, [0, -1]] = wall

    inner_walls = [
        (2, 2), (3, 2), (4, 4), (5, 6), (7, 6), (7, 2), (3, 5), (2, 6), (3, 7), (3, 6), (3, 4), (5, 2), (6, 2),
        (6, 4), (7, 4), (2, 4),
    ]
    board[tuple(np.array(inner_walls).T)] = wall

    exits = [(5, 8)]
    board[exits[0]] = -2
    start = (1, 0)
    board[start] = -1
    return board, cells, board.copy(), start, exits
def display_results(arange, len_moves):
    """Plot the number of moves taken in each training episode.

    Args:
        arange: episode numbers (x axis), e.g. the first value returned by
            ``Maze.train``.
        len_moves: moves taken in each episode (y axis).
    """
    title = 'Залежність кількості ходів від кількості епізодів навчання'
    plt.figure(figsize=(15, 6))
    plt.title(title)
    plt.plot(arange, len_moves)
    plt.xlabel('Номер епізоду')
    plt.ylabel('Кількість ходів')
    plt.grid()
    plt.show()
# Quick run on the small 9x9 maze; 14 is passed as the reference minimum
# move count shown in the final title.
(maze_small, maze_idx_small, maze_with_path_small,
 current_position_small, exit_small) = create_small_maze()
game_small = Maze(
    maze=maze_small,
    maze_idx=maze_idx_small,
    maze_with_path=maze_with_path_small,
    current_position=current_position_small,
    exit=exit_small,
)
game_small.display_maze(game_small.maze)
agent_small = QLearningAgent(learning_rate=10**-2, epsilon=0.2)
arange_small, len_moves_small = game_small.train(agent=agent_small, epochs=10000, use_low_eps=False)
game_small.play(agent=agent_small, len_sh_path=14)
display_results(arange=arange_small, len_moves=len_moves_small)

# Full 39x39 maze: show the empty board and the board with the reference
# routes, then train with epsilon decay enabled.
(maze, maze_idx, maze_with_path,
 current_position, exit) = create_maze()
game = Maze(
    maze=maze,
    maze_idx=maze_idx,
    maze_with_path=maze_with_path,
    current_position=current_position,
    exit=exit,
)
game.display_maze(game.maze)
game.display_maze(game.maze_with_path)
agent = QLearningAgent(learning_rate=10**-2, epsilon=0.2)
arange, len_moves = game.train(agent=agent, epochs=25000, use_low_eps=True)
game.play(agent=agent)
display_results(arange=arange, len_moves=len_moves)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement