alkkofficial

Untitled

Aug 2nd, 2024
from collections import defaultdict
import random
import statistics

import gymnasium as gym
import matplotlib.pyplot as plt

env = gym.make("CartPole-v1")

discount = 0.9
turns = 1_000_000
actions = (0, 1)

# discretized state -> learned q-value
q_table = defaultdict(lambda: 0.0)

# discretized state -> number of times state has been seen
state_count = defaultdict(lambda: 0)

def greediness(count):
    """
    Given the number of times a state has been seen, return the probability of choosing
    the action with the highest q-value (as opposed to choosing a random action).

    The greediness should increase as the agent gains more experience with a state (and
    hence has less reason to explore). It should be a number between 0.0 and 1.0.
    """
    ...

def learning_rate(count):
    """
    Return the learning rate as a function of the number of times a state has been seen.

    The learning rate should decrease with experience, so that what the agent has learned
    is effectively 'locked in' and not forgotten.
    """
    ...

def discretize(state):
    """
    Take a state in the continuous state space and return a discrete version of that state.

    Discretization should group states into a relatively small number of "buckets", so that
    the agent can generalize from past experiences.
    """
    ...

def choose_action(state):
    """
    Given a state, choose an action, weighing the need to explore against the need to
    exploit what the agent has already learned.

    The agent should choose what it believes to be the 'optimal' action with probability
    `greediness(count)`, where `count` is the number of times the discretized state has
    been seen, and otherwise choose a random action.
    """
    ...

def update_tables(state, action, new_state, reward):
    """
    Update the `q_table` and `state_count` tables, based on the observed transition.
    """
    ...

episode_lengths = []

episode_length = 0
state, info = env.reset()

for turn in range(turns):
    action = choose_action(state)
    old_state = state
    state, reward, terminated, truncated, info = env.step(action)

    # Override the environment's reward: a large penalty when the episode
    # terminates (failure), and +1 for every other step, including truncation.
    if terminated:
        reward = -100
    else:
        reward = +1

    update_tables(old_state, action, state, reward)
    episode_length += 1

    if terminated or truncated:
        episode_lengths.append(episode_length)
        episode_length = 0
        state, info = env.reset()

# Moving average of episode length over windows of 100 episodes.
moving_average = [
    statistics.mean(episode_lengths[i:i+100])
    for i in range(len(episode_lengths)-100)
]
plt.plot(moving_average)
plt.show()
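
The five stubbed functions above are left open in the paste. What follows is one possible way to fill them in with tabular Q-learning, offered as a minimal sketch and not as the author's intended solution. Several things here are assumptions that do not appear in the original: the `BOUNDS` clipping ranges, the `BUCKETS` counts, the `MIN_RATE` floor, the particular decay formulas, and keying `q_table` by a `(discretized state, action)` pair. The block leaves out `gymnasium` so it can run on its own. The functions keep the same names and signatures as the skeleton, so they could be dropped in place of the stubs.

```python
import random
from collections import defaultdict

actions = (0, 1)
discount = 0.9

# Assumption: q-values are keyed by (discretized state, action).
q_table = defaultdict(float)
# discretized state -> number of times the state has been seen
state_count = defaultdict(int)

# Assumed clipping ranges for (cart position, cart velocity,
# pole angle, pole angular velocity) and bucket counts per dimension.
BOUNDS = ((-2.4, 2.4), (-3.0, 3.0), (-0.21, 0.21), (-3.5, 3.5))
BUCKETS = (3, 3, 6, 6)
MIN_RATE = 0.05  # assumed floor so learning never stops entirely


def greediness(count):
    # 0.0 for an unseen state, approaching 1.0 as the count grows.
    return count / (count + 1)


def learning_rate(count):
    # 1.0 on the first visit, decaying towards MIN_RATE.
    return max(MIN_RATE, 1.0 / (1 + count))


def discretize(state):
    bucketed = []
    for value, (low, high), n in zip(state, BOUNDS, BUCKETS):
        clipped = min(max(value, low), high)
        fraction = (clipped - low) / (high - low)
        # The upper bound maps to the last bucket, not one past it.
        bucketed.append(min(int(fraction * n), n - 1))
    return tuple(bucketed)


def choose_action(state):
    s = discretize(state)
    if random.random() < greediness(state_count[s]):
        # Exploit: pick the action with the highest learned q-value.
        return max(actions, key=lambda a: q_table[(s, a)])
    # Explore: pick uniformly at random.
    return random.choice(actions)


def update_tables(state, action, new_state, reward):
    s, s2 = discretize(state), discretize(new_state)
    alpha = learning_rate(state_count[s])
    best_next = max(q_table[(s2, a)] for a in actions)
    target = reward + discount * best_next
    # Standard Q-learning update towards the bootstrapped target.
    q_table[(s, action)] += alpha * (target - q_table[(s, action)])
    state_count[s] += 1
```

Because `update_tables` is not told whether the transition ended the episode, this sketch still bootstraps from the bucket of the final state. With the `-100` penalty that is probably tolerable, but passing a `terminated` flag and zeroing `best_next` would be the more careful variant.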