Trainlover08

VEX_AI_SIM/robot_controller/include/robotFunctions.cpp

Oct 29th, 2024
// (VEX_AI_SIM/robot_controller/include/robotFunctions.cpp)

// library for sim functions
#include "include/simFunctions.cpp"

// create the sim object used by the robot functions below
Sim sim;

class Bot {
private:
    double currentScore;
    bool training = true;
    int currentTerm = 25;
    int previousTerm = 0;
    bool E_NEW_NET = true;
public:
    const void* message = " ";
    double functionOutput;

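    // Actuator helpers: each takes a bool and moves the matching simulated motor
    // to a preset position or velocity; false returns it to its rest or stopped
    // state.
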
    // activate or deactivate the hook
    void hook(bool activated) {
        if (activated) {
            Hook->setPosition(0.2);
        } else {
            Hook->setPosition(0);
        }
    }

    // move the top ring arm to its active or rest position
    void topArm(bool activated) {
        if (activated) {
            topRingArm->setPosition(-1.2);
        } else {
            topRingArm->setPosition(0);
        }
    }

    // move the intake holder to its active or rest position
    void holder(bool activated) {
        if (activated) {
            intakeHold->setPosition(-0.5);
        } else {
            intakeHold->setPosition(0);
        }
    }

    // move both ring clamps between their two preset positions
    void clamp(bool activated) {
        if (activated) {
            ringClampL->setPosition(0.3);
            ringClampR->setPosition(0.3);
        } else {
            ringClampL->setPosition(-0.3);
            ringClampR->setPosition(-0.3);
        }
    }

    // run or stop the intake roller
    void intake(bool spinning) {
        if (spinning) {
            intakeRoll->setVelocity(-100.0);
        } else {
            intakeRoll->setVelocity(0);
        }
    }

    // run or stop the conveyor
    void Conveyor(bool spinning) {
        if (spinning) {
            conveyor->setVelocity(-10.0);
        } else {
            conveyor->setVelocity(0);
        }
    }

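    // Overview: trainingNetwork builds a 3-128-128-1 actor and critic, then for
    // each episode resets the sim, feeds the averaged drive velocities and the
    // current score to the actor, converts the actor's scalar output into a robot
    // action via functionConvert, accumulates shaped rewards, and updates both
    // networks from one-step TD advantages. Parameters are saved to and reloaded
    // from text files between episodes.
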
    // training function for the actor and critic networks
    void trainingNetwork(NeuralNetwork& actor, NeuralNetwork& critic, int numEpisodes, double gamma, double learningRate, double GRADIENT_CLASH_THRESHOLD, double weight_decay) {
        // initialize variables
        vector<double> state;
        double input1;
        double input2;
        double input3;
        int extraReward = 0;
        int runCount = 0;

        AdamWOptimizer actorOptimizer(learningRate, 0.9, 0.999, 0.01, weight_decay);
        AdamWOptimizer criticOptimizer(learningRate, 0.9, 0.999, 0.01, weight_decay);

        actor.add_layer(Layer(3, 128, "relu", actorOptimizer));
        actor.add_layer(Layer(128, 128, "relu", actorOptimizer));
        actor.add_layer(Layer(128, 1, "linear", actorOptimizer));

        critic.add_layer(Layer(3, 128, "relu", criticOptimizer));
        critic.add_layer(Layer(128, 128, "relu", criticOptimizer));
        critic.add_layer(Layer(128, 1, "linear", criticOptimizer));

        if (E_NEW_NET) {
            // save the freshly initialized networks to file
            actor.save("actor_network_params.txt");
            critic.save("critic_network_params.txt");
        }

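        // Each episode reloads the latest saved parameters, so progress carries
        // over from the save at the end of every term below.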
        for (int episode = 0; episode <= numEpisodes; ++episode) {
            // load the neural networks from file
            NeuralNetwork actorLoadedNN;
            NeuralNetwork criticLoadedNN;
            actorLoadedNN.load("actor_network_params.txt");
            criticLoadedNN.load("critic_network_params.txt");

            actor = actorLoadedNN;
            critic = criticLoadedNN;

            fstream AIshots("ai_rec.html");

            vector<vector<double>> states;
            vector<double> actions, rewards, logProbs, values;

            // make sure the bot is stopped before resetting the sim
            if (left1->getVelocity() != 0.0) {
                sim.moveBot(0);
                sim.delay(50, "msec");
            }
            sim.resetSimManual();
            sim.programSetup();
            robot->animationStartRecording("ai_rec.html");
            training = true;
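            // one episode: step the actor-critic loop until the current term expires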
            while (training) {
                runCount = runCount + 1;

                // average the drive velocities, and insert them into the state array
                input1 = (left1->getVelocity() + left2->getVelocity() + left3->getVelocity()) / 3;
                input2 = (right1->getVelocity() + right2->getVelocity() + right3->getVelocity()) / 3;
                input3 = currentScore;

                // erase the vector, and insert the new state
                state.assign({input1, input2, input3});
                states.push_back(state);

                vector<vector<double>> actionProbs = actor.forward({state});

                vector<vector<double>> valueEstimates = critic.forward({state});

                values.push_back(valueEstimates[0][0]);

                sim.delay(64, "msec");

                // the actor has a single output, so index 0 is the action
                double action = actionProbs[0][0];
                logProbs.push_back(log(max(actionProbs[0][0], 1e-8)));

                // scale the action into the 0..1 range expected by functionConvert
                if (action > 0) {
                    functionOutput = action * 10;
                } else if (action < 0) {
                    functionOutput = action * -10;
                }

                if (functionOutput > 1) {
                    functionOutput = functionOutput / 30;
                }

                cout << "MAINBOT: functionOutput = " << functionOutput << endl;
                functionConvert(functionOutput);

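                // reward shaping: bonus points for driving, accelerating, seeing
                // objects, turning, and raising the arm or holder; the bonus is cut
                // down when the bot is nearly stationary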
                if (left1->getVelocity() >= 0.1) {
                    extraReward = extraReward + 20;
                }

                if (left1->getAcceleration() >= 0.01) {
                    extraReward = extraReward + 3;
                }

                if (invisVision->getRecognitionNumberOfObjects() > 0) {
                    extraReward = extraReward + (invisVision->getRecognitionNumberOfObjects() * 3);
                }

                if (left1->getVelocity() >= 0 && right1->getVelocity() <= 0) {
                    extraReward = extraReward + 15;
                } else if (left1->getVelocity() <= 0 && right1->getVelocity() >= 0) {
                    extraReward = extraReward + 20;
                }

                if (topRingArm->getTargetPosition() > 0) {
                    extraReward = extraReward + 10;
                }

                if (intakeHold->getTargetPosition() > 0) {
                    extraReward = extraReward + 10;
                }

                if (yes->getSpeed() <= 0.05) {
                    extraReward = extraReward / 3;
                }

                // read the latest score from the receiver and record the reward
                sim.receive();
                if (receiv->getQueueLength() >= 1) {
                    message = receiv->getData();
                    currentScore = *(double *)message;
                    rewards.push_back(input1 + input2 + input3 + extraReward);
                    receiv->nextPacket();
                }

                // end the episode once the current term expires, and save the networks
                if (robot->getTime() >= currentTerm) {
                    training = false;
                    previousTerm = currentTerm;
                    currentTerm = currentTerm + 25;
                    extraReward = 0;

                    actor.save("actor_network_params.txt");
                    critic.save("critic_network_params.txt");
                }
            }

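            // one-step TD advantage for each step:
            //   A_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            // (the last step has no successor, so its target is just r_t)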
            vector<double> advantages;
            for (size_t t = 0; t < rewards.size(); ++t) {
                double td_target = rewards[t] + (t < rewards.size() - 1 ? gamma * values[t + 1] : 0.0);
                advantages.push_back(td_target - values[t]);
            }

            // stop the animation recording after the final episode
            if (episode == numEpisodes) {
                robot->animationStopRecording();
            }

            double actorLoss = computeLoss(logProbs, advantages);

            // critic loss: mean squared TD error over the episode
            double criticLoss = 0.0;
            for (size_t i = 0; i < rewards.size(); ++i) {
                double td_target = rewards[i] + (i < rewards.size() - 1 ? gamma * values[i + 1] : 0.0);
                criticLoss += pow(td_target - values[i], 2);
            }
            criticLoss /= rewards.size();

            actor.backward({{actorLoss}}, GRADIENT_CLASH_THRESHOLD);
            actor.update_weights();

            critic.backward({{criticLoss}}, GRADIENT_CLASH_THRESHOLD);
            critic.update_weights();
        }
    }

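    // policy-gradient style loss: the negative sum of log-probabilities weighted
    // by their advantages, so minimizing it favors actions with positive advantage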
    double computeLoss(const vector<double>& logProbs, const vector<double>& advantages) {
        double loss = -0.005;
        for (size_t i = 0; i < logProbs.size(); ++i) {
            loss -= logProbs[i] * advantages[i];
        }
        return loss;
    }

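    // map the actor's scalar output onto discrete actions: values up to 0.5 select
    // a drive command (anything outside those bands stops the bot), and values
    // between 0.5 and 0.9 toggle the hook or the top arm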
    void functionConvert(double functionID) {
        if (functionID >= -0.1 && functionID <= 0.1) {
            sim.moveBot(0);
        } else if (functionID >= 0.1 && functionID <= 0.2) {
            sim.moveBot(1);
        } else if (functionID >= 0.2 && functionID <= 0.3) {
            sim.moveBot(2);
        } else if (functionID >= 0.3 && functionID <= 0.4) {
            sim.moveBot(3);
        } else if (functionID >= 0.4 && functionID <= 0.5) {
            sim.moveBot(4);
        } else {
            sim.moveBot(0);
        }

        if (functionID >= 0.5 && functionID <= 0.6) {
            hook(true);
        } else if (functionID >= 0.6 && functionID <= 0.7) {
            hook(false);
        } else if (functionID >= 0.7 && functionID <= 0.8) {
            topArm(true);
        } else if (functionID >= 0.8 && functionID <= 0.9) {
            topArm(false);
        }
    }
};
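
// Example usage (a minimal sketch; the parameter values below are illustrative
// and not taken from this project):
//
//   Bot bot;
//   NeuralNetwork actor;
//   NeuralNetwork critic;
//   bot.trainingNetwork(actor, critic, /*numEpisodes=*/100, /*gamma=*/0.99,
//                       /*learningRate=*/0.001, /*GRADIENT_CLASH_THRESHOLD=*/1.0,
//                       /*weight_decay=*/0.01);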