Advertisement
pasholnahuy

Neurolinks

Dec 6th, 2023
677
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
C++ 3.15 KB | None | 0 0
#include <cmath>
#include <functional>
#include <iostream>
#include <utility>

#include "eigen/Eigen/Core"
  4.  
  5. using Eigen::MatrixXd;
  6. using Eigen::VectorXd;
  7.  
  8. enum class Threshold_Id {
  9.     Sigmoid,
  10.     ReLu
  11. };
  12.  
  13. struct Threshold_Database {
  14.     template<Threshold_Id>
  15.     static double evaluate_0(double);
  16.  
  17.     template<Threshold_Id>
  18.     static double evaluate_1(double);
  19.  
  20.     template<>
  21.     inline double evaluate_0<Threshold_Id::Sigmoid>(double x) {
  22.         return 1. / (1. + std::exp(-x));
  23.     }
  24.  
  25.     template<>
  26.     inline double evaluate_1<Threshold_Id::Sigmoid>(double x) {
  27.         return std::exp(-x) * evaluate_0<Threshold_Id::Sigmoid>(x) * evaluate_0<Threshold_Id::Sigmoid>(x);
  28.     }
  29.  
  30.  
  31.     template<>
  32.     inline double evaluate_0<Threshold_Id::ReLu>(double x) {
  33.         return x > 0 ? x : 0;
  34.     }
  35.  
  36.     template<>
  37.     inline double evaluate_1<Threshold_Id::ReLu>(double x) {
  38.         return x > 0 ? 1 : 0;
  39.     }
  40.  
  41. };
  42.  
  43. template<Threshold_Id>
  44. struct F_ID {
  45. };
  46.  
  47. class Threshold_Func {
  48.     using FunctionType = std::function<double(double)>;
  49. public:
  50.     Threshold_Func(FunctionType evaluate_0, FunctionType evaluate_1) : evaluate_0_(std::move(evaluate_0)),
  51.                                                                        evaluate_1_(std::move(evaluate_1)) {
  52.     }
  53.  
  54.     template<Threshold_Id Id>
  55.     Threshold_Func(F_ID<Id>): evaluate_0_(Threshold_Database::evaluate_0<Id>),
  56.                               evaluate_1_(Threshold_Database::evaluate_1<Id>) {
  57.  
  58.     }
  59.  
  60.     double evaluate_0(double x) const {
  61.         return evaluate_0_(x);
  62.     }
  63.  
  64.     double evaluate_1(double x) const {
  65.         return evaluate_1_(x);
  66.     }
  67.  
  68.     VectorXd apply(const VectorXd &vec) const {
  69.         return vec.unaryExpr([this](double x) { return evaluate_0(x); });
  70.     }
  71.  
  72.     VectorXd derive(const VectorXd &vec) const {
  73.         return vec.unaryExpr([this](double x) { return evaluate_1(x); });
  74.  
  75.     }
  76.  
  77. private:
  78.     FunctionType evaluate_0_;
  79.     FunctionType evaluate_1_;
  80. };
  81.  
  82. class Layer {
  83. public:
  84.     VectorXd apply(const VectorXd &x) const { // vector of values
  85.         return sigma_.apply(A_ * x + b_);
  86.     }
  87.  
  88.     MatrixXd derive(const VectorXd &vec) const { // vec is a matrix of y_i = (Ax + b)_i - result of apply
  89.         return sigma_.derive(vec).asDiagonal();
  90.     }
  91.  
  92.     MatrixXd gradA(const VectorXd &x, const VectorXd &u, const VectorXd &vec) const { // u is a gradient vector
  93.         return derive(vec) * u.transpose() * x.transpose();
  94.     }
  95.  
  96.     MatrixXd gradb(const VectorXd &u, const VectorXd &vec) const {
  97.         return derive(vec) * u.transpose();
  98.     }
  99.  
  100.     VectorXd gradx(const VectorXd &x, const VectorXd &u, const VectorXd &vec) const {
  101.         return (A_.transpose() * derive(vec) * u.transpose()).transpose();
  102.     }
  103.  
  104. private:
  105.     Threshold_Func sigma_;
  106.     MatrixXd A_;
  107.     VectorXd b_;
  108. };
  109.  
  110. double Score(MatrixXd res, MatrixXd reference) {
  111.     return (res - reference).cwiseAbs().sum();
  112. }
  113.  
  114. int main() {
  115.     Threshold_Func(F_ID<Threshold_Id::Sigmoid>());
  116.  
  117.     std::cout << "Hello, World!" << std::endl;
  118.     using Vector3f = Eigen::Matrix<float, 3, 1>;
  119.     Vector3f a;
  120.     for (auto el: a) {
  121.         std::cout << el << '\n';
  122.     }
  123.     return 0;
  124. }
  125.  
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement