import gymnasium as gym
import random
from collections import defaultdict
import matplotlib.pyplot as plt
import statistics
from tqdm import tqdm

env = gym.make("CartPole-v1")
q_table = defaultdict(lambda: 0.0)  # Q-values keyed by (discretized state, action)
counts = defaultdict(lambda: 0)     # visit count per discretized state
discount = 0.9
turns = 1_000_000                   # total environment steps, not episodes
episode_lengths = []
actions = (0, 1)                    # CartPole's two actions: push left / push right
state, info = env.reset()

def greediness(count):
    # Probability of acting greedily; rises from 0 toward 1 as a state
    # is visited more often (per-state decaying exploration).
    return 1.0 - 20 / (20 + count)

def learning_rate(count):
    # Step size; decays from 1 toward 0 as a state is visited more often.
    return 30 / (30 + count)
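
# For intuition, a few values of the schedules above (computed from the
# definitions; not in the original paste):
#   greediness(0) = 0.0,    greediness(20) = 0.5,    greediness(180) = 0.9
#   learning_rate(0) = 1.0, learning_rate(30) = 0.5, learning_rate(270) = 0.1
# so rarely-visited states explore more and learn faster.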

def discretize(state):
    # Bucket the continuous observation (cart position, cart velocity,
    # pole angle, pole angular velocity) into integers; the pole angle
    # gets the finest resolution.
    return (
        int( 2.0 * state[0]),
        int( 2.0 * state[1]),
        int(20.0 * state[2]),
        int( 2.0 * state[3]),
    )
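
# Example (not in the original paste): an observation of
# (0.3, -1.2, 0.05, 0.7) lands in bucket (0, -2, 1, 1). Note that int()
# truncates toward zero, so in each dimension the bucket straddling zero
# is twice as wide as the others.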

def choose_action(state):
    # Epsilon-greedy with a per-state epsilon: exploit with probability
    # greediness(counts[state]), otherwise act at random.
    state = discretize(state)
    if random.random() < greediness(counts[state]):
        return max(actions, key=lambda action: q_table[state, action])
    else:
        return random.choice(actions)
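
# update_table applies the standard tabular Q-learning update, which the
# body below implements directly:
#     Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
# with alpha = learning_rate(counts[s]) and gamma = discount.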
def update_table(state, action, new_state, reward):
    state = discretize(state)
    new_state = discretize(new_state)
    q_table[state, action] += learning_rate(counts[state]) * (
        reward
        + discount * max(q_table[new_state, a] for a in actions)
        - q_table[state, action]
    )
    counts[state] += 1

episode_length = 0
for turn in tqdm(range(turns)):
    action = choose_action(state)
    episode_length += 1
    old_state = state
    state, reward, terminated, truncated, info = env.step(action)
    if terminated:
        # Reward shaping: punish dropping the pole (or running off the
        # track) much harder than the environment's default +1 per step.
        reward = -100
    else:
        # Ordinary steps, and truncation at the 500-step time limit
        # (a success), keep the normal +1 reward.
        reward = +1
    update_table(old_state, action, state, reward)
    if terminated or truncated:
        state, info = env.reset()
        episode_lengths.append(episode_length)
        episode_length = 0

# Plot a 100-episode moving average of episode length as a learning curve.
moving_average = [
    statistics.mean(episode_lengths[i:i + 100])
    for i in range(len(episode_lengths) - 100)
]
plt.plot(moving_average)
plt.show()
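
# Optional follow-up (a minimal sketch, not part of the original paste):
# watch the learned policy by re-creating the environment with
# render_mode="human" and always acting greedily on q_table.
eval_env = gym.make("CartPole-v1", render_mode="human")
for _ in range(5):
    obs, info = eval_env.reset()
    done = False
    while not done:
        s = discretize(obs)
        obs, reward, terminated, truncated, info = eval_env.step(
            max(actions, key=lambda a: q_table[s, a])
        )
        done = terminated or truncated
eval_env.close()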