Advertisement
Mochinov

Untitled

Jun 10th, 2023 (edited)
182
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.83 KB | None | 0 0
  1. import pygame
  2. import numpy as np
  3. import tensorflow as tf
  4.  
  5. # Инициализация Pygame
  6. pygame.init()
  7. screen_width = 800
  8. screen_height = 600
  9. screen = pygame.display.set_mode((screen_width, screen_height))
  10. clock = pygame.time.Clock()
  11.  
  12. # Создание модели искусственного интеллекта
  13. model = tf.keras.models.Sequential([
  14. tf.keras.layers.Dense(32, activation='relu', input_shape=(2,)),
  15. tf.keras.layers.Dense(32, activation='relu'),
  16. tf.keras.layers.Dense(2, activation='softmax')
  17. ])
  18. model.compile(optimizer='adam', loss='categorical_crossentropy')
  19.  
  20. # Задание трассы
  21. track_points = [(100, 200), (300, 100), (500, 300), (700, 200)]
  22.  
  23. # Начальное положение и скорость машинки
  24. car_pos = np.array([track_points[0][0], track_points[0][1]])
  25. car_speed = np.array([5, 0])
  26.  
  27. # Обучение модели
  28. learning_rate = 0.001
  29. discount_factor = 0.99
  30. epsilon = 0.1
  31.  
  32. for episode in range(1000):
  33. state = car_pos.copy()
  34. total_reward = 0
  35.  
  36. while True:
  37. # Выбор действия на основе epsilon-greedy
  38. if np.random.rand() < epsilon:
  39. action = np.random.randint(2)
  40. else:
  41. q_values = model.predict(state.reshape(1, 2))
  42. action = np.argmax(q_values[0])
  43.  
  44. # Исполнение действия
  45. if action == 0:
  46. car_speed[1] -= 0.5
  47. else:
  48. car_speed[1] += 0.5
  49.  
  50. # Ограничение скорости машинки
  51. car_speed[1] = np.clip(car_speed[1], -5, 5)
  52.  
  53. # Обновление положения машинки
  54. car_pos += car_speed
  55.  
  56. # Расчет вознаграждения
  57. if car_pos[0] > track_points[-1][0]:
  58. reward = 1
  59. else:
  60. reward = 0
  61.  
  62. # Обновление модели с помощью обучения с подкреплением
  63. target = reward + discount_factor * np.max(model.predict(car_pos.reshape(1, 2))[0])
  64. q_values_target = model.predict(state.reshape(1, 2))
  65. q_values_target[0][action] = target
  66. model.fit(state.reshape(1, 2), q_values_target, epochs=1, verbose=0)
  67.  
  68. # Обновление состояния и суммарного вознаграждения
  69. state = car_pos.copy()
  70. total_reward += reward
  71.  
  72. # Отрисовка игры
  73. screen.fill((255, 255, 255))
  74. pygame.draw.lines(screen, (0, 0, 0), False, track_points, 5)
  75. pygame.draw.circle(screen, (255, 0, 0), car_pos.astype(int), 10)
  76. pygame.display.flip()
  77. clock.tick(60)
  78.  
  79. if car_pos[0] > screen_width:
  80. break
  81.  
  82. print("Episode:", episode, "Total Reward:", total_reward)
  83.  
  84. pygame.quit()
  85.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement