Advertisement
pasholnahuy

Untitled

Jan 17th, 2024
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.82 KB | None | 0 0
  1. #pragma once
  2.  
#include <cmath>
#include <initializer_list>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

#include <Eigen/Eigen>
#include <EigenRand/EigenRand>

#include "Layer.h"
#include "Score_Func.h"
  11.  
  12. namespace network {
  13.  
  14. struct Values {
  15. using MatrixXd = Eigen::MatrixXd;
  16. template <class T> using vector = std::vector<T>;
  17. vector<MatrixXd> in;
  18. vector<MatrixXd> out;
  19. };
  20.  
  21. class Network {
  22. template <class T> using vector = std::vector<T>;
  23.  
  24. public:
  25. using MatrixXd = Eigen::MatrixXd;
  26. using VectorXd = Eigen::VectorXd;
  27. using PermutationMatrix =
  28. Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic>;
  29.  
  30. Network(std::initializer_list<int> dimensions,
  31. std::initializer_list<Threshold_Id> threshold_id);
  32.  
  33. Values Forward_Prop(const MatrixXd &start_vec);
  34. VectorXd Back_Prop(const MatrixXd &start_vec, const MatrixXd &reference,
  35. const Score_Func &score_func, double coef);
  36.  
  37. VectorXd Back_Prop_SGD(const MatrixXd &start_batch, const MatrixXd &reference,
  38. const Score_Func &score_func, int iter_num);
  39. VectorXd Back_Prop_MBGD(const MatrixXd &start_batch,
  40. const MatrixXd &reference,
  41. const Score_Func &score_func, int iter_num);
  42.  
  43. void TrainSGD(const MatrixXd &start_batch, const MatrixXd &reference,
  44. const Score_Func &score_func, double needed_accuracy,
  45. int max_epochs);
  46. void TrainBGD(const MatrixXd &start_batch, const MatrixXd &reference,
  47. const Score_Func &score_func, int cols_in_minibatch,
  48. double needed_accuracy, int max_epochs);
  49.  
  50. private:
  51. static VectorXd Cols_Mean(const MatrixXd &x);
  52.  
  53. vector<Layer> layers_;
  54. vector<Threshold_Id> threshold_id_;
  55. inline static std::minstd_rand index_generator;
  56. };
  57. } // namespace network
  58.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement