pong_dqn
exotic666 | Dec 10th, 2024 | Python
# Import necessary libraries
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cv2
import random
import os
from collections import deque

# Hyperparameters and paths
class Config:
    num_episodes = 1800
    max_steps_per_episode = 1000
    buffer_capacity = 50000
    batch_size = 32
    gamma = 0.99
    lr = 1e-4
    epsilon_start = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    target_update_freq = 2000  # measured in environment steps
    save_path = "/content/drive/MyDrive/checkpoints"

# DQN model: three conv layers over a stack of four 84x84 grayscale frames,
# followed by two fully connected layers that output one Q-value per action
class DQN(nn.Module):
    def __init__(self, input_channels, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        x = x / 255.0  # Normalize pixel values to [0, 1]
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

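# Minimal shape-check sketch (illustration only, never called by the training
# code below): pushes a dummy batch of four stacked 84x84 frames through the
# network and confirms it returns one Q-value per action. The default
# n_actions=6 is an assumption matching Pong's discrete action space.
def _dqn_shape_check(n_actions=6):
    net = DQN(input_channels=4, n_actions=n_actions)
    dummy = torch.zeros(1, 4, 84, 84)  # one stacked observation
    with torch.no_grad():
        q_values = net(dummy)
    assert q_values.shape == (1, n_actions)
    return q_values.shape
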
# Replay buffer: bounded FIFO of (state, action, reward, next_state, done) tuples
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, *experience):
        self.buffer.append(experience)  # oldest entry is dropped automatically at capacity

    def sample(self, batch_size):
        samples = random.sample(self.buffer, batch_size)
        return map(np.array, zip(*samples))

    def __len__(self):
        return len(self.buffer)

# Frame preprocessing: convert the RGB frame to grayscale and resize to 84x84
def preprocess_frame(frame):
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    return resized

# Frame stacking: keep the four most recent frames as the agent's state
def stack_frames(frames, frame, is_new_episode=False):
    if is_new_episode:
        frames = [frame] * 4
    else:
        frames = frames[1:] + [frame]
    return np.stack(frames, axis=0), frames

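# Illustration only (not called anywhere): runs one dummy frame through the
# preprocessing and stacking helpers above to show the resulting shapes. The
# 210x160x3 frame size is an assumption matching Atari Pong's raw observations.
def _frame_pipeline_example():
    raw = np.zeros((210, 160, 3), dtype=np.uint8)    # dummy RGB frame
    small = preprocess_frame(raw)                     # -> (84, 84) grayscale
    state, frames = stack_frames([], small, is_new_episode=True)
    assert small.shape == (84, 84)
    assert state.shape == (4, 84, 84)                 # network input shape
    return state.shape
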
# Epsilon-greedy action selection
def select_action(state, policy_net, epsilon, n_actions, device):
    if np.random.rand() < epsilon:
        return np.random.randint(n_actions)
    else:
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            return policy_net(state).argmax().item()

# Save model weights
def save_model(policy_net, filename):
    os.makedirs(Config.save_path, exist_ok=True)
    torch.save(policy_net.state_dict(), filename)

# Load model weights (map_location="cpu" so GPU checkpoints also load on CPU-only machines)
def load_model(policy_net, filename):
    if os.path.exists(filename):
        policy_net.load_state_dict(torch.load(filename, map_location="cpu"))
        print(f"Model loaded from {filename}")
    else:
        print(f"No checkpoint found at {filename}")

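# Illustration only (not called by train/test): round-trips a freshly
# initialised network through save_model/load_model to show how the two
# helpers fit together. The filename and n_actions=6 are assumptions.
def _checkpoint_roundtrip_example():
    net_a = DQN(4, 6)
    path = f"{Config.save_path}/example_checkpoint.pth"
    save_model(net_a, path)
    net_b = DQN(4, 6)
    load_model(net_b, path)
    # After loading, both networks hold identical weights
    for p_a, p_b in zip(net_a.parameters(), net_b.parameters()):
        assert torch.equal(p_a, p_b)
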
# Training function
def train():
    env = gym.make("PongNoFrameskip-v4")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_actions = env.action_space.n

    policy_net = DQN(4, n_actions).to(device)
    target_net = DQN(4, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=Config.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.9)

    buffer = ReplayBuffer(Config.buffer_capacity)
    epsilon = Config.epsilon_start
    global_step = 0

    for episode in range(Config.num_episodes):
        frame = preprocess_frame(env.reset())
        state, frames = stack_frames([], frame, is_new_episode=True)
        total_reward, done = 0, False

        for _ in range(Config.max_steps_per_episode):
            action = select_action(state, policy_net, epsilon, n_actions, device)
            next_frame, reward, done, *_ = env.step(action)
            next_frame = preprocess_frame(next_frame)
            next_state, frames = stack_frames(frames, next_frame, is_new_episode=False)

            buffer.push(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
            global_step += 1

            if len(buffer) >= Config.batch_size:
                states, actions, rewards, next_states, dones = buffer.sample(Config.batch_size)
                states = torch.tensor(states, dtype=torch.float32).to(device)
                actions = torch.tensor(actions, dtype=torch.long).to(device)
                rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
                next_states = torch.tensor(next_states, dtype=torch.float32).to(device)
                dones = torch.tensor(dones, dtype=torch.float32).to(device)

                # Current Q-values for the actions actually taken
                q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
                # Bootstrapped targets from the target network (no gradients needed)
                with torch.no_grad():
                    next_q_values = target_net(next_states).max(1)[0]
                    targets = rewards + Config.gamma * next_q_values * (1 - dones)

                loss = nn.SmoothL1Loss()(q_values, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

            # Sync the target network every target_update_freq environment steps
            if global_step % Config.target_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())

            if done:
                break

        epsilon = max(Config.epsilon_min, epsilon * Config.epsilon_decay)

        if episode % 100 == 0:
            save_model(policy_net, f"{Config.save_path}/model_episode_{episode}.pth")

        print(f"Episode {episode}, Total Reward: {total_reward}")

    save_model(policy_net, f"{Config.save_path}/final_model.pth")
    print("Training complete. Model saved.")
    env.close()

# Testing function
def test():
    env = gym.make("PongNoFrameskip-v4")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_actions = env.action_space.n

    policy_net = DQN(4, n_actions).to(device)
    load_model(policy_net, f"{Config.save_path}/final_model.pth")
    policy_net.eval()

    frame = preprocess_frame(env.reset())
    state, frames = stack_frames([], frame, is_new_episode=True)
    total_reward, done = 0, False

    while not done:
        env.render()
        action = select_action(state, policy_net, 0, n_actions, device)  # greedy: epsilon = 0
        next_frame, reward, done, *_ = env.step(action)
        next_frame = preprocess_frame(next_frame)
        state, frames = stack_frames(frames, next_frame, is_new_episode=False)
        total_reward += reward

    print(f"Test complete. Total Reward: {total_reward}")
    env.close()

# Main execution
if __name__ == "__main__":
    train()
    test()
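
# Note: this script assumes the classic Gym Atari API ("PongNoFrameskip-v4",
# env.reset() returning a single observation and env.step() returning four
# values), with the Atari ROMs installed, e.g. via
# `pip install gym[atari,accept-rom-license] opencv-python`. Under Gymnasium,
# reset()/step() return extra values and render() requires a render_mode,
# so small adjustments would be needed there.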