#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <algorithm>
#include "neural_network.cpp"
// #include "linear_regression.cpp"
class CartPoleEnv {
public:
    CartPoleEnv()
        : gravity(9.8), massCart(1.0), massPole(0.1), length(0.5),
          forceMag(10.0), tau(0.02), thetaThresholdRadians(12 * 2 * M_PI / 360),
          xThreshold(2.4), totalMass(massCart + massPole), polemassLength(massPole * length),
          state{0.0, 0.0, 0.0, 0.0}, done(false) {}

    // Reset the state to small random values around the upright equilibrium.
    void reset() {
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_real_distribution<> dis(-0.05, 0.05);
        state[0] = dis(gen); // cart position
        state[1] = dis(gen); // cart velocity
        state[2] = dis(gen); // pole angle
        state[3] = dis(gen); // pole angular velocity
        done = false;
    }

    // Advance the simulation by one time step using Euler integration of the
    // classic cart-pole dynamics (action 1 pushes right, action 0 pushes left).
    void step(int action) {
        double x_dot = state[1];
        double theta = state[2];
        double theta_dot = state[3];

        double force = (action == 1) ? forceMag : -forceMag;
        double costheta = cos(theta);
        double sintheta = sin(theta);

        double temp = (force + polemassLength * theta_dot * theta_dot * sintheta) / totalMass;
        double theta_acc = (gravity * sintheta - costheta * temp) /
                           (length * (4.0 / 3.0 - massPole * costheta * costheta / totalMass));
        double x_acc = temp - polemassLength * theta_acc * costheta / totalMass;

        // Update state
        state[0] += tau * x_dot;
        state[1] += tau * x_acc;
        state[2] += tau * theta_dot;
        state[3] += tau * theta_acc;

        // Check termination against the updated state
        done = (state[0] < -xThreshold || state[0] > xThreshold ||
                state[2] < -thetaThresholdRadians || state[2] > thetaThresholdRadians);
    }

    bool isDone() const {
        return done;
    }

    std::vector<double> getState() const {
        return {state[0], state[1], state[2], state[3]};
    }

    double getReward() const {
        return (std::abs(state[0]) < xThreshold && std::abs(state[2]) < thetaThresholdRadians) ? 1.0 : 0.0;
    }

private:
    const double gravity;
    const double massCart;
    const double massPole;
    const double length; // actually half the pole's length
    const double forceMag;
    const double tau; // seconds between state updates
    const double thetaThresholdRadians;
    const double xThreshold;
    const double totalMass;
    const double polemassLength;
    double state[4];
    bool done;
};
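
// Illustrative only: a minimal sketch of how CartPoleEnv can be driven with a
// random policy. The helper name runRandomEpisode is not part of the original
// paste; it simply demonstrates the reset/step/getReward/isDone cycle that the
// training loop below relies on.
double runRandomEpisode(CartPoleEnv& env) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> coin(0, 1); // action 0 = push left, 1 = push right
    double episodeReturn = 0.0;
    env.reset();
    while (!env.isDone()) {
        env.step(coin(gen));              // apply a random force direction
        episodeReturn += env.getReward(); // +1 for every step the pole stays up
    }
    return episodeReturn;
}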
// Compute the advantage using the reward-to-go method
double computeAdvantage(const std::vector<double>& rewards, int t, double gamma) {
    double advantage = 0.0;
    double discount = 1.0;
    for (size_t i = t; i < rewards.size(); ++i) {
        advantage += discount * rewards[i];
        discount *= gamma;
    }
    return advantage;
}
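
// Worked example (assumed values, not from the original paste): for
// rewards = {1, 1, 1}, t = 0 and gamma = 0.99, the reward-to-go is
// 1 + 0.99 * 1 + 0.99^2 * 1 = 1 + 0.99 + 0.9801 = 2.9701.
// Note that trainCartPolePolicyGradient below uses one-step TD advantages
// from the critic rather than this helper.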
// Compute the loss using negative log likelihood
double computeLoss(const std::vector<double>& logProbs, const std::vector<double>& advantages) {
    double loss = 0.0;
    for (size_t i = 0; i < logProbs.size(); ++i) {
        loss -= logProbs[i] * advantages[i];
    }
    return loss;
}
// Gradient ascent to update parameters
void gradientAscent(std::vector<double>& params, const std::vector<double>& gradients, double learningRate) {
    std::transform(params.begin(), params.end(), gradients.begin(), params.begin(),
                   [learningRate](double p, double g) { return p + learningRate * g; });
}
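
// The training loop below reads the actor's forward() output as action
// probabilities and takes its log, yet the actor's final layer is declared
// "linear". Whether those outputs are normalized therefore depends on
// neural_network.cpp, which is not shown here. If it does not apply a softmax
// internally, a helper like this sketch (hypothetical, not part of the
// original paste) could normalize the two output scores before the log is taken.
std::vector<double> softmax(const std::vector<double>& logits) {
    double maxLogit = *std::max_element(logits.begin(), logits.end()); // subtract max for numerical stability
    double sum = 0.0;
    std::vector<double> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - maxLogit);
        sum += probs[i];
    }
    for (double& p : probs) {
        p /= sum;
    }
    return probs;
}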
// Main training loop for policy gradient
void trainCartPolePolicyGradient(NeuralNetwork& actor, NeuralNetwork& critic, CartPoleEnv& env,
                                 int numEpisodes, double gamma, double learningRate,
                                 double GRADIENT_CLIP_THRESHOLD, double weight_decay) {
    AdamWOptimizer actorOptimizer(learningRate, 0.9, 0.999, 0.01, weight_decay);
    AdamWOptimizer criticOptimizer(learningRate, 0.9, 0.999, 0.01, weight_decay);

    // Initialize the actor and critic networks
    actor.add_layer(Layer(4, 32, "relu", actorOptimizer));
    actor.add_layer(Layer(32, 16, "relu", actorOptimizer));
    actor.add_layer(Layer(16, 2, "linear", actorOptimizer));

    critic.add_layer(Layer(4, 16, "relu", criticOptimizer));
    critic.add_layer(Layer(16, 16, "relu", criticOptimizer));
    critic.add_layer(Layer(16, 1, "linear", criticOptimizer)); // Single output for state value

    for (int episode = 0; episode < numEpisodes; ++episode) {
        std::vector<std::vector<double>> states;
        std::vector<double> actions, rewards, logProbs, values;

        env.reset();
        while (!env.isDone()) {
            std::vector<double> state = env.getState();
            states.push_back(state);

            // Get action probabilities from the actor network
            std::vector<std::vector<double>> actionProbs = actor.forward({state});
            // Get the value estimate from the critic network
            std::vector<std::vector<double>> valueEstimates = critic.forward({state});
            values.push_back(valueEstimates[0][0]);

            // Greedy action selection: take the action with the higher score
            int action = (actionProbs[0][0] > actionProbs[0][1]) ? 0 : 1;
            logProbs.push_back(std::log(std::max(actionProbs[0][action], 1e-8)));

            // Take the action in the environment
            env.step(action);
            // Store the reward
            rewards.push_back(env.getReward());
        }

        // Compute one-step TD advantages using the critic's value estimates
        std::vector<double> advantages;
        for (size_t t = 0; t < rewards.size(); ++t) {
            double td_target = rewards[t] + (t + 1 < rewards.size() ? gamma * values[t + 1] : 0.0);
            advantages.push_back(td_target - values[t]);
        }

        // Compute the policy (actor) loss
        double actorLoss = computeLoss(logProbs, advantages);
        std::cout << "Episode " << episode << ", Actor Loss: " << actorLoss << std::endl;

        // Compute the critic loss (mean squared error of the TD targets)
        double criticLoss = 0.0;
        for (size_t i = 0; i < rewards.size(); ++i) {
            double td_target = rewards[i] + (i + 1 < rewards.size() ? gamma * values[i + 1] : 0.0);
            criticLoss += pow(td_target - values[i], 2);
        }
        criticLoss /= rewards.size();
        std::cout << "Episode " << episode << ", Critic Loss: " << criticLoss << std::endl;

        // Backpropagate and update the actor network
        actor.backward({{actorLoss}}, GRADIENT_CLIP_THRESHOLD);
        actor.update_weights();

        // Backpropagate and update the critic network
        critic.backward({{criticLoss}}, GRADIENT_CLIP_THRESHOLD);
        critic.update_weights();
    }
}
// Main function to run the training
int main() {
    CartPoleEnv env;
    NeuralNetwork actor;
    NeuralNetwork critic_1;
    NeuralNetwork critic_2; // unused here; kept for the commented-out SAC variant below
    NeuralNetwork target;   // unused here; kept for the commented-out SAC variant below
    trainCartPolePolicyGradient(actor, critic_1, env, 500, 0.99, 0.001, 0.05, 1e-4);
    // trainCartPoleSAC(actor, critic_1, critic_2, target, env, 500, 0.99, 0.0075, 0.15, 0.001, 0.05, 1e-4, 1e5);
    return 0;
}
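
// Build sketch (assumptions: this paste is saved as policy_gradient.cpp, and
// neural_network.cpp, which defines NeuralNetwork, Layer, and AdamWOptimizer,
// sits in the same directory; since it is pulled in via #include above, only
// this one file is compiled):
//   g++ -std=c++17 -O2 policy_gradient.cpp -o cartpole_pg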