alkkofficial

Untitled

Aug 6th, 2024
import gymnasium as gym
import random
from collections import defaultdict
import matplotlib.pyplot as plt
import statistics
from tqdm import tqdm

# Tabular Q-learning on CartPole-v1 with a discretized observation space.
env = gym.make("CartPole-v1")
q_table = defaultdict(lambda: 0.0)   # maps (discrete state, action) -> Q-value
counts = defaultdict(lambda: 0)      # visit count per discrete state
discount = 0.9
turns = 1_000_000                    # total environment steps, not episodes
episode_lengths = []

actions = (0, 1)                     # push cart left / push cart right

state, info = env.reset()

def greediness(count):
    # Probability of acting greedily; rises toward 1 as a state is visited more often.
    return 1.0 - 20 / (20 + count)

def learning_rate(count):
    # Step size; decays toward 0 as a state is visited more often.
    return 30 / (30 + count)

def discretize(state):
    # Bucket the continuous observation (cart position, cart velocity, pole
    # angle, pole angular velocity) into integers usable as dictionary keys.
    return (
        int( 2.0 * state[0]),
        int( 2.0 * state[1]),
        int(20.0 * state[2]),
        int( 2.0 * state[3]),
    )

def choose_action(state):
    # Epsilon-greedy action selection, with exploration shrinking per state.
    state = discretize(state)
    if random.random() < greediness(counts[state]):
        return max(actions, key=lambda a: q_table[state, a])
    else:
        return random.choice(actions)

def update_table(state, action, new_state, reward):
    # Standard Q-learning update:
    # Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
    state = discretize(state)
    new_state = discretize(new_state)
    q_table[state, action] += learning_rate(counts[state]) * (
        reward
        + discount * max(q_table[new_state, a] for a in actions)
        - q_table[state, action]
    )
    counts[state] += 1

episode_length = 0

for turn in tqdm(range(turns)):
    action = choose_action(state)
    episode_length += 1
    old_state = state
    state, reward, terminated, truncated, info = env.step(action)

    # Reward shaping: penalize the pole falling over; otherwise keep the
    # usual +1 per surviving step (truncation at the time limit also gets +1).
    if terminated:
        reward = -100
    else:
        reward = +1
    update_table(old_state, action, state, reward)

    if terminated or truncated:
        state, info = env.reset()
        episode_lengths.append(episode_length)
        episode_length = 0

# Plot a 100-episode moving average of episode length to show learning progress.
moving_average = [
    statistics.mean(episode_lengths[i:i + 100])
    for i in range(len(episode_lengths) - 100)
]
plt.plot(moving_average)
plt.show()
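
To watch the learned policy after training, a minimal sketch like the one below could be appended; it is not part of the original paste and assumes the q_table, discretize, and actions defined above are still in scope, using Gymnasium's human render mode.

# Hedged sketch (not in the original paste): roll out one episode greedily
# with the learned q_table, rendering to a window.
eval_env = gym.make("CartPole-v1", render_mode="human")
state, info = eval_env.reset()
done = False
while not done:
    d = discretize(state)
    action = max(actions, key=lambda a: q_table[d, a])
    state, reward, terminated, truncated, info = eval_env.step(action)
    done = terminated or truncated
eval_env.close()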