cart-pole.cpp
Trainlover08, Aug 19th, 2024 (edited)

#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <algorithm>

#include "neural_network.cpp"
// #include "linear_regression.cpp"

class CartPoleEnv {
public:
    CartPoleEnv()
        : gravity(9.8), massCart(1.0), massPole(0.1), length(0.5),
          forceMag(10.0), tau(0.02), thetaThresholdRadians(12 * 2 * M_PI / 360),
          xThreshold(2.4), totalMass(massCart + massPole), polemassLength(massPole * length),
          state{0.0, 0.0, 0.0, 0.0}, done(false) {}

    void reset() {
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_real_distribution<> dis(-0.05, 0.05);

        state[0] = dis(gen); // cart position
        state[1] = dis(gen); // cart velocity
        state[2] = dis(gen); // pole angle
        state[3] = dis(gen); // pole angular velocity

        done = false;
    }

    void step(int action) {
        double x = state[0];
        double x_dot = state[1];
        double theta = state[2];
        double theta_dot = state[3];

        double force = (action == 1) ? forceMag : -forceMag;
        double costheta = cos(theta);
        double sintheta = sin(theta);

        double temp = (force + polemassLength * theta_dot * theta_dot * sintheta) / totalMass;
        double theta_acc = (gravity * sintheta - costheta * temp) /
                           (length * (4.0 / 3.0 - massPole * costheta * costheta / totalMass));
        double x_acc = temp - polemassLength * theta_acc * costheta / totalMass;

        // Update state with one Euler step of size tau
        state[0] = x + tau * x_dot;
        state[1] = x_dot + tau * x_acc;
        state[2] = theta + tau * theta_dot;
        state[3] = theta_dot + tau * theta_acc;

        // Check termination on the updated state (cart off the track or pole past the angle limit)
        done = (state[0] < -xThreshold || state[0] > xThreshold ||
                state[2] < -thetaThresholdRadians || state[2] > thetaThresholdRadians);
    }

    bool isDone() const {
        return done;
    }

    std::vector<double> getState() const {
        return {state[0], state[1], state[2], state[3]};
    }

    double getReward() const {
        return (std::abs(state[0]) < xThreshold && std::abs(state[2]) < thetaThresholdRadians) ? 1.0 : 0.0;
    }

private:
    const double gravity;
    const double massCart;
    const double massPole;
    const double length;  // actually half the pole's length
    const double forceMag;
    const double tau;  // seconds between state updates
    const double thetaThresholdRadians;
    const double xThreshold;
    const double totalMass;
    const double polemassLength;

    double state[4];
    bool done;
};
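
// Usage sketch (illustrative only, not called by the training code below):
// roll out one episode with a uniformly random policy using just the public
// CartPoleEnv interface above, and return the total reward collected.
double runRandomEpisodeSketch(CartPoleEnv& env) {
    std::mt19937 gen(std::random_device{}());
    std::uniform_int_distribution<int> actionDist(0, 1);  // 0 = push left, 1 = push right

    env.reset();
    double totalReward = 0.0;
    while (!env.isDone()) {
        env.step(actionDist(gen));        // apply a random force direction
        totalReward += env.getReward();   // getReward() pays 1.0 while the cart and pole stay within bounds
    }
    return totalReward;
}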

// Compute the reward-to-go (discounted return) from timestep t onward
double computeAdvantage(const std::vector<double>& rewards, int t, double gamma) {
    double advantage = 0.0;
    double discount = 1.0;
    for (size_t i = static_cast<size_t>(t); i < rewards.size(); ++i) {
        advantage += discount * rewards[i];
        discount *= gamma;
    }
    return advantage;
}

// Compute the policy loss as the negative log-likelihood weighted by the advantages
double computeLoss(const std::vector<double>& logProbs, const std::vector<double>& advantages) {
    double loss = 0.0;
    for (size_t i = 0; i < logProbs.size(); ++i) {
        loss -= logProbs[i] * advantages[i];
    }
    return loss;
}

// Gradient ascent: move the parameters in the direction of the gradients
void gradientAscent(std::vector<double>& params, const std::vector<double>& gradients, double learningRate) {
    std::transform(params.begin(), params.end(), gradients.begin(), params.begin(), [learningRate](double p, double g) {
        return p + learningRate * g;
    });
}
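
// Illustrative sketch (not used by the training loop below): how computeAdvantage,
// computeLoss, and gradientAscent fit together for a plain REINFORCE-style update
// on one recorded trajectory. The parameter and gradient vectors here are
// placeholders; the real loop below drives the NeuralNetwork classes instead.
void reinforceUpdateSketch(const std::vector<double>& rewards,
                           const std::vector<double>& logProbs,
                           std::vector<double>& params,
                           const std::vector<double>& paramGradients,
                           double gamma, double learningRate) {
    // Reward-to-go advantage for every timestep of the trajectory
    std::vector<double> advantages;
    for (size_t t = 0; t < rewards.size(); ++t) {
        advantages.push_back(computeAdvantage(rewards, static_cast<int>(t), gamma));
    }

    // Scalar surrogate loss: -sum_t logProb_t * advantage_t
    double loss = computeLoss(logProbs, advantages);
    std::cout << "Surrogate loss: " << loss << std::endl;

    // Step the (placeholder) parameters along their gradients
    gradientAscent(params, paramGradients, learningRate);
}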

// Main training loop for the policy gradient agent
void trainCartPolePolicyGradient(NeuralNetwork& actor, NeuralNetwork& critic, CartPoleEnv& env, int numEpisodes, double gamma, double learningRate, double GRADIENT_CLIP_THRESHOLD, double weight_decay) {
    AdamWOptimizer actorOptimizer(learningRate, 0.9, 0.999, 0.01, weight_decay);
    AdamWOptimizer criticOptimizer(learningRate, 0.9, 0.999, 0.01, weight_decay);

    // Initialize the actor and critic networks
    actor.add_layer(Layer(4, 32, "relu", actorOptimizer));
    actor.add_layer(Layer(32, 16, "relu", actorOptimizer));
    actor.add_layer(Layer(16, 2, "linear", actorOptimizer));  // Two action logits, turned into probabilities below

    critic.add_layer(Layer(4, 16, "relu", criticOptimizer));
    critic.add_layer(Layer(16, 16, "relu", criticOptimizer));
    critic.add_layer(Layer(16, 1, "linear", criticOptimizer));  // Single output for the state value

    // Random number generator shared across episodes for sampling actions
    std::mt19937 gen(std::random_device{}());
    std::uniform_real_distribution<> uniform01(0.0, 1.0);

    for (int episode = 0; episode < numEpisodes; ++episode) {
        std::vector<std::vector<double>> states;
        std::vector<double> actions, rewards, logProbs, values;

        env.reset();
        // Note: episodes end only when the environment reports failure; there is no step cap
        while (!env.isDone()) {
            std::vector<double> state = env.getState();
            states.push_back(state);

            // Get the two action logits from the actor network
            std::vector<std::vector<double>> actorOut = actor.forward({state});

            // Get the value estimate from the critic network
            std::vector<std::vector<double>> valueEstimates = critic.forward({state});
            values.push_back(valueEstimates[0][0]);

            // Convert the logits to probabilities with a numerically stable softmax
            double maxLogit = std::max(actorOut[0][0], actorOut[0][1]);
            double exp0 = std::exp(actorOut[0][0] - maxLogit);
            double exp1 = std::exp(actorOut[0][1] - maxLogit);
            double probRight = exp1 / (exp0 + exp1);

            // Sample an action based on the probabilities
            int action = (uniform01(gen) < probRight) ? 1 : 0;
            double actionProb = (action == 1) ? probRight : 1.0 - probRight;
            actions.push_back(action);
            logProbs.push_back(std::log(std::max(actionProb, 1e-8)));

            // Take the action in the environment
            env.step(action);

            // Store the reward
            rewards.push_back(env.getReward());
        }

        // Compute TD(0) advantages from the critic's value estimates
        std::vector<double> advantages;
        for (size_t t = 0; t < rewards.size(); ++t) {
            double td_target = rewards[t] + (t < rewards.size() - 1 ? gamma * values[t + 1] : 0.0);
            advantages.push_back(td_target - values[t]);
        }

        // Compute the policy (actor) loss
        double actorLoss = computeLoss(logProbs, advantages);
        std::cout << "Episode " << episode << ", Actor Loss: " << actorLoss << std::endl;

        // Compute the critic loss (mean squared TD error)
        double criticLoss = 0.0;
        for (size_t i = 0; i < rewards.size(); ++i) {
            double td_target = rewards[i] + (i < rewards.size() - 1 ? gamma * values[i + 1] : 0.0);
            criticLoss += pow(td_target - values[i], 2);
        }
        criticLoss /= rewards.size();
        std::cout << "Episode " << episode << ", Critic Loss: " << criticLoss << std::endl;

        // Backpropagate and update the actor network
        // (backward() comes from neural_network.cpp and is fed the scalar loss here)
        actor.backward({{actorLoss}}, GRADIENT_CLIP_THRESHOLD);
        actor.update_weights();

        // Backpropagate and update the critic network
        critic.backward({{criticLoss}}, GRADIENT_CLIP_THRESHOLD);
        critic.update_weights();
    }
}

// Main function to run the training
int main() {
    CartPoleEnv env;
    NeuralNetwork actor;
    NeuralNetwork critic_1;
    NeuralNetwork critic_2;
    NeuralNetwork target;

    trainCartPolePolicyGradient(actor, critic_1, env, 500, 0.99, 0.001, 0.05, 1e-4);

    //trainCartPoleSAC(actor, critic_1, critic_2, target, env, 500, 0.99, 0.0075, 0.15, 0.001, 0.05, 1e-4, 1e5);

    return 0;
}
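
// Build note (assumed setup): because neural_network.cpp is #included directly at
// the top of this file, only this translation unit needs to be compiled, e.g.
//   g++ -std=c++17 -O2 cart-pole.cpp -o cart-pole
// provided neural_network.cpp (defining NeuralNetwork, Layer, AdamWOptimizer) is in
// the same directory.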