from collections import defaultdict
import random
import statistics

import gymnasium as gym
import matplotlib.pyplot as plt

env = gym.make("CartPole-v1")
discount = 0.9
turns = 1_000_000
actions = (0, 1)

# discretized state -> learned q-value
q_table = defaultdict(lambda: 0.0)

# discretized state -> number of times state has been seen
state_count = defaultdict(lambda: 0)

def greediness(count):
    """
    Given the number of times a state has been seen, return the probability of choosing
    the action with the highest q-value (as opposed to choosing a random action).
    The greediness should increase as the agent gains more experience with a state (and
    hence has less reason to explore). It should be a number between 0.0 and 1.0.
    """
    ...
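
# One possible schedule for the stub above (a sketch, not from the original
# paste; the constant 100 is an arbitrary choice): greediness rises from 0.0
# towards 1.0 as the state is visited more often.
def greediness(count):
    return count / (count + 100)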

def learning_rate(count):
    """
    Return the learning rate as a function of the number of times a state has been seen.
    The learning rate should decrease with experience, so that what the agent has learned
    is effectively 'locked in' and not forgotten.
    """
    ...
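
# A possible learning-rate schedule (a sketch; the 1 / (1 + count) decay shape
# is an assumption): start at 1.0 and shrink towards 0.0 with experience.
def learning_rate(count):
    return 1.0 / (1.0 + count)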

def discretize(state):
    """
    Take a state in the continuous state space and return a discrete version of that state.
    Discretization should group states into a relatively small number of "buckets", so that
    the agent can generalize from past experiences.
    """
    ...
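
# A sketch of one possible discretization (the bucket width of 0.25 is an
# arbitrary assumption): round each of the four CartPole observation
# components to the nearest multiple of 0.25 and return a hashable tuple.
def discretize(state):
    return tuple(round(x / 0.25) for x in state)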

def choose_action(state):
    """
    Given a state, choose an action, weighing up the need to explore with the need to
    exploit what the agent has already learned for optimal behaviour.
    The agent should choose what it believes to be the 'optimal' action with probability
    given by `greediness` (applied to the state's visit count), and otherwise choose a
    random action.
    """
    ...
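
# A sketch of an epsilon-greedy style choice for the stub above. It assumes
# q_table is keyed by (discretized state, action) pairs, which the original
# paste leaves open.
def choose_action(state):
    bucket = discretize(state)
    if random.random() < greediness(state_count[bucket]):
        # Exploit: pick the action with the highest learned q-value.
        return max(actions, key=lambda a: q_table[(bucket, a)])
    # Explore: pick a random action.
    return random.choice(actions)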

def update_tables(state, action, new_state, reward):
    """
    Update the `q_table` and `state_count` tables, based on the observed transition.
    """
    ...
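
# A sketch of the standard one-step tabular Q-learning update for the stub
# above (again assuming (discretized state, action) keys):
#     Q(s, a) += rate * (reward + discount * max_a' Q(s', a') - Q(s, a))
def update_tables(state, action, new_state, reward):
    bucket = discretize(state)
    new_bucket = discretize(new_state)
    state_count[bucket] += 1
    rate = learning_rate(state_count[bucket])
    best_next = max(q_table[(new_bucket, a)] for a in actions)
    target = reward + discount * best_next
    q_table[(bucket, action)] += rate * (target - q_table[(bucket, action)])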

episode_lengths = []
episode_length = 0

state, info = env.reset()
for turn in range(turns):
    action = choose_action(state)
    old_state = state
    state, reward, terminated, truncated, info = env.step(action)

    # Shape the reward: heavily penalise the pole falling over (termination);
    # every other step, including hitting the time limit (truncation), earns +1.
    if terminated:
        reward = -100
    elif truncated:
        reward = +1
    else:
        reward = +1

    update_tables(old_state, action, state, reward)

    episode_length += 1
    if terminated or truncated:
        episode_lengths.append(episode_length)
        episode_length = 0
        state, info = env.reset()

# Plot a 100-episode moving average of episode length to see learning progress.
moving_average = [
    statistics.mean(episode_lengths[i:i + 100])
    for i in range(len(episode_lengths) - 100)
]
plt.plot(moving_average)
plt.show()