Advertisement
riabcis

Untitled

Apr 5th, 2018
596
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 12.82 KB | None | 0 0
  1. // Ann.cpp : Defines the entry point for the console application.
  2. //
  3.  
  4. #include "stdafx.h"
  5. # include <cmath>
  6. #include <math.h>
  7. #include <vector>
  8. #include <iostream>
  9. #include <iomanip>
  10. #include <fstream>
  11. #include <string>
  12. #include <array>
  13. using namespace std;
  14. //Generating random number: either 0 or 1, uniform distribution, for XOR operation. Can remove later if using data from files.
  15. int randint();
  16. double f(double x);
  17. double f_deriv(double x);
  18. double gL(double a, double z, double t);
  19. double w_gradient(int layer_id, int w_i, int w_j, double *a_arr, int *s, double *gll);
  20. double delta_w(double grad, double dw);
  21. void calc_gjl(double *a_arr, double *z_arr, double *t_arr, double *w_arr, int *s, int *sw, int L, int *l, double *gll);
  22. const double ETA = 0.1;
  23. const double ALPHA = 0.15;
  24. void file(string filename, double *a, int dydis);
  25.  
  26. struct Topology
  27. {
  28.     std::vector<int> l;//kiekiai sluoksnyje
  29. } topolygy;
  30.  
  31. struct Sample
  32. {
  33.     double input[2];
  34.     double output[2];
  35.     //double * input=new double[topolygy.l.at(0)];
  36.     //double * output=new double[topolygy.l.at(topolygy.l.size()-1)];
  37.     //Sample(double a[],double b[]) {
  38.     //  input = a;
  39.     //  output = b;
  40.     //};
  41.     //temp*******
  42.     string ToString() {
  43.         string str;
  44.         str = "input: " + to_string(input[0]) + " " + to_string(input[1]) + " output: " + to_string(output[0]) + " " + to_string(output[1]) + "\n";
  45.         return str;
  46.     }
  47. };
  48. class Data
  49. {
  50. public:
  51.     int getNumberOfInputs() { return inputs; }
  52.     int getNumberOfOutputs() { return outputs; }
  53.  
  54.     double * getInput(int index);
  55.  
  56.     double * getOutput(int index);
  57.  
  58.     int getNumberOfSamples() { return samples; }
  59.  
  60.     void addSample(Sample sample);
  61.  
  62.     void setSizes(int input_size, int output_size);
  63.  
  64. protected:
  65.     std::vector<Sample> data;
  66.     int inputs;
  67.     int outputs;
  68.     int samples = 0;
  69. };
  70.  
  71. class XOR : public Data
  72. {
  73. public:
  74.     void generate(int n);
  75.  
  76.     XOR()
  77.     {
  78.         inputs = 2;
  79.         outputs = 2;
  80.         samples = 0;
  81.     }
  82.     void printInputs(int index);
  83.  
  84.     void printOutputs(int index);
  85. };
  86.  
  87. class AnnBase {
  88. public:
  89.     virtual void prepare(Topology top) = 0;
  90.     virtual void init(Topology top, double w_arr_1[]) = 0;
  91.     virtual void train(double *a, double *b) = 0;
  92.     virtual void feedForward(double *a, double *b) = 0;
  93.     virtual void destroy() = 0;
  94. private:
  95.     virtual void calc_feedForward() = 0;
  96. };
  97.  
  98. class AnnSerialDBL : public AnnBase {
  99. public:
  100.     void prepare(Topology top);
  101.     void init(Topology top, double w_arr_1[]);
  102.     void train(double *a, double *b);
  103.     void feedForward(double *a, double *b);
  104.     void destroy();
  105. private:
  106.     void calc_feedForward();
  107. public:
  108.     int z_count;//temp var to keep the length of z, so z could be reset for calcs.
  109.     int inputs;
  110.     int outputs;
  111.     int L;
  112.     int * l;
  113.     int * s;
  114.     double * a_arr;
  115.     double * z_arr;
  116.     int * W;
  117.     int * sw;
  118.     double * w_arr;
  119.     double * dw_arr;
  120.     double * t_arr;
  121.     double * gjl;
  122. };
  123. void print_all(AnnSerialDBL SerialDBL, int sum, int mult, int i);
  124. void read_W(string filename, double *w_arr);
  125. void calc_sizes(int &sum, int &mult, Topology top);
  126. int main()
  127. {
  128.     //15+24+10  49
  129.     topolygy.l.push_back(2);
  130.     topolygy.l.push_back(5);
  131.     topolygy.l.push_back(4);
  132.     //topolygy.l.push_back(7);
  133.     topolygy.l.push_back(2);
  134.     AnnSerialDBL SerialDBL;
  135.  
  136.    
  137.     int sum = 0;
  138.     int mult = 0;
  139.     calc_sizes(sum, mult, topolygy);
  140.  
  141.     SerialDBL.prepare(topolygy);
  142.  
  143.     double * a1=new double[mult];
  144.     read_W("W_ARR.txt", a1);
  145.  
  146.     SerialDBL.init(topolygy, a1);
  147.     delete[] a1;
  148.     a1 = NULL;
  149.  
  150.     print_all(SerialDBL, sum, mult, 0);
  151.  
  152.     XOR xo;
  153.     xo.generate(100000);
  154.     SerialDBL.train(xo.getInput(0), xo.getOutput(0));
  155.     print_all(SerialDBL, sum, mult, 1);
  156.     for (int i = 1; i < xo.getNumberOfSamples(); i++) {
  157.         SerialDBL.train(xo.getInput(i), xo.getOutput(i));
  158.     }
  159.  
  160.     //Checking results(all combinations 0 and 1)
  161.     for (double i = 0; i < 2; i++) {
  162.         for (double j = 0; j < 2; j++) {
  163.             double input[] = { i ,j };
  164.             double output[] = { 0,0 };
  165.             SerialDBL.feedForward(input, output);
  166.  
  167.             cout << endl << "input : " << input[0] << "   " << input[1] << endl;
  168.             cout << endl << "output: " << output[0] << "   " << output[1] << endl << endl;
  169.             cout << "---------------------------------------------------" << endl;
  170.         }
  171.     }
  172.  
  173.     //Checking results(all combinations 0 and 1)
  174.     for (double i = 0; i < 0; i++) {
  175.         double input[] = { randint()*1.0, randint()*1.0 };
  176.         double output[] = { 0,0 };
  177.         SerialDBL.feedForward(input, output);
  178.  
  179.         cout << endl << "input : " << input[0] << "   " << input[1] << endl;
  180.         cout << endl << "output: " << output[0] << "   " << output[1] << endl << endl;
  181.         cout << "---------------------------------------------------" << endl;
  182.     }
  183.  
  184.     print_all(SerialDBL, sum, mult, 2);
  185.  
  186.  
  187.     SerialDBL.destroy();
  188.  
  189.     int a;
  190.     cin >> a;
  191.  
  192.     return 0;
  193. }
  194. //returns random int, either 0 or 1
  195. int randint() {
  196.     double r = ((double)rand() / (RAND_MAX));
  197.     int a = 0;
  198.     if (r > 0.5) {
  199.         a = 1;
  200.     }
  201.     else
  202.     {
  203.         a = 0;
  204.     }
  205.     return a;
  206. }
  207.  
  208. double f(double x) {
  209.     double y = 1 + exp(-x);
  210.     //temp*********************
  211.     if (y == 0) {
  212.         cout << "Error 1";
  213.     }
  214.     if ((y - 1) == 0) {
  215.         //cout << "Error 2";
  216.     }
  217.     //temp**********************
  218.     return 1 / y;
  219. }
  220.  
  221. double f_deriv(double x) {
  222.     //Temp**********
  223.     double y = pow((1 + exp(-x)), 2);
  224.     double z = exp(-x);
  225.     if (y == 0) {
  226.         cout << "Error 3";
  227.     }
  228.     if (z == 0) {
  229.     //  cout << "Error 4";
  230.     }
  231.     //temp**********************
  232.     return exp(-x) / pow((1 + exp(-x)), 2);
  233. }
  234.  
  235. double gL(double a, double z, double t) {
  236.     double w = f_deriv(z) * (a - t);
  237.     //cout << "z: " << z << " a: " << a << " t: " << t << endl;
  238.     return w;
  239. }
  240.  
  241. double w_gradient(int layer_id, int w_i, int w_j, double *a_arr, int *s, double *gll) {
  242.     return a_arr[s[layer_id] + w_i] * gll[s[layer_id+1] + w_j];
  243. }
  244.  
  245. double delta_w(double grad, double dw) {
  246.     return (-ETA)*grad + ALPHA*dw;
  247. }
  248.  
  249. void calc_gjl(double *a_arr, double *z_arr, double *t_arr, double *w_arr, int *s, int *sw, int L, int *l, double *gll)
  250. {
  251.     for (int i = L-2; i >= 0; i--) {
  252.         for (int j = 0; j < l[i + 1] - 1; j++) {
  253.             if (L - 2 == i) {
  254.                 gll[s[i + 1] + j] = gL(a_arr[s[i + 1] + j], z_arr[s[i + 1] + j], t_arr[j]);
  255.             }
  256.             else {
  257.                 gll[s[i + 1] + j] = f_deriv(z_arr[s[i + 1] + j]);
  258.                 double sum = 0;
  259.                 for (int k = 0; k < l[i+2]-1; k++) {
  260.                     sum += w_arr[sw[i + 1] + k*(l[i + 2] - 1) + j] * gll[s[i + 2] + j];
  261.                 }
  262.                 gll[s[i + 1] + j] *= sum;
  263.             }
  264.         }
  265.     }
  266. }
  267.  
  268. //*********
  269. double * Data::getInput(int index)
  270. {
  271.     return data[index].input;
  272. }
  273.  
  274. double * Data::getOutput(int index)
  275. {
  276.     return data[index].output;
  277. }
  278.  
  279. void Data::addSample(Sample sample)
  280. {
  281.     data.push_back(sample);
  282.     samples++;
  283.     //cout << sample.ToString();
  284. }
  285.  
  286. void Data::setSizes(int input_size, int output_size)
  287. {
  288.     inputs = input_size;
  289.     outputs = output_size;
  290. }
  291.  
  292. //****************
  293. void XOR::generate(int n)
  294. {
  295.     for (int i = 0; i < n / 4; i++)
  296.     {
  297.         //double input1 = randint();
  298.         //double input2 = randint();
  299.         for (double j = 0; j < 2; j++) {
  300.             for (double k = 0; k < 2; k++) {
  301.                 double output1 = j == k;
  302.                 double output2 = j != k;
  303.                 double input[] = { j,k };
  304.                 double output[] = { output1, output2 };
  305.                 //Sample s = { input,output };
  306.                 //addSample({ input,output });
  307.                 addSample({ {j,k} ,{output1,output2} });
  308.             }
  309.         }
  310.  
  311.     }
  312. }
  313.  
  314. void XOR::printInputs(int index)
  315. {
  316.     cout << index << " index inputs: " << endl;
  317.     for (int i = 0; i < inputs; i++)
  318.     {
  319.         cout << getInput(index)[i] << " ";
  320.     }
  321.     cout << endl;
  322. }
  323.  
  324. void XOR::printOutputs(int index)
  325. {
  326.     cout << index << " index outputs: " << endl;
  327.     for (int i = 0; i < outputs; i++)
  328.     {
  329.         cout << fixed << setprecision(2) << data[index].output[i] << " ";
  330.     }
  331.     cout << endl;
  332. }
  333.  
  334. //****************
  335. void AnnSerialDBL::prepare(Topology top)
  336. {
  337.     inputs = top.l.at(0);
  338.     outputs = top.l.at(top.l.size() - 1);
  339.  
  340.     l = new int[top.l.size()];
  341.     s = new int[top.l.size()];
  342.  
  343.     int sum = 0;
  344.     int mult = 0;
  345.     for (int i = 0; i < top.l.size(); i++) {
  346.         sum += top.l.at(i) + 1;
  347.     }
  348.     z_count = sum;
  349.     for (int i = 0; i < top.l.size() - 1; i++) {
  350.         mult += (top.l.at(i) + 1)*top.l.at(i + 1);
  351.     }
  352.     a_arr = new double[sum];
  353.     z_arr = new double[sum];
  354.  
  355.     W = new int[top.l.size()];
  356.     sw = new int[top.l.size()];
  357.  
  358.     w_arr = new double[mult];
  359.     dw_arr = new double[mult];
  360.  
  361.     t_arr = new double[top.l.at(top.l.size() - 1)];
  362.  
  363.     gjl = new double[sum];
  364. }
  365.  
  366. void AnnSerialDBL::init(Topology top, double w_arr_1[] = NULL)
  367. {
  368.     L = top.l.size();
  369.     //Neuronu kiekiai sluoksnyje
  370.     for (int i = 0; i < top.l.size(); i++) {
  371.         l[i] = top.l.at(i) + 1;
  372.     }
  373.  
  374.     //Sluoksniu pradzios indeksai
  375.     for (int i = 0; i < top.l.size(); i++) {
  376.         s[i] = 0;
  377.         for (int j = i; j > 0; j--) {
  378.             s[i] += l[j - 1];
  379.         }
  380.     }
  381.  
  382.     //Bias neuronai
  383.     for (int i = 0; i < top.l.size() - 1; i++) {
  384.         a_arr[s[i + 1] - 1] = 1;
  385.     }
  386.  
  387.     //Svoriu kiekiai l-ame sluoksnyje
  388.     for (int i = 0; i < top.l.size() - 1; i++) {
  389.         W[i] = l[i] * (l[i + 1] - 1);
  390.         //cout << "Svoriu sk: " << W[i] << " Pradzios index: ";
  391.         sw[i] = 0;
  392.         if (i != 0) {
  393.             for (int j = 0; j < i; j++) {
  394.                 sw[i] += W[j];
  395.             }
  396.         }
  397.         if (w_arr_1 == NULL) {
  398.             for (int j = 0; j < W[i]; j++) {
  399.                 w_arr[sw[i] + j] = (double)rand() / double(RAND_MAX);
  400.                 //cout << w_arr[sw[i] + j]<< endl;
  401.                 dw_arr[sw[i] + j] = 0;
  402.             }
  403.         }
  404.         else {
  405.             //w_arr = w_arr_1; //ar reikia pokycius issisaugoti irgi?
  406.             for (int j = 0; j < W[i]; j++) {
  407.                 w_arr[sw[i] + j] = w_arr_1[sw[i] + j];
  408.                 //cout << w_arr[sw[i] + j]<< endl;
  409.                 dw_arr[sw[i] + j] = 0;
  410.             }
  411.         }
  412.  
  413.         //cout << sw[i] << " " << endl;
  414.     }
  415. }
  416.  
  417. void AnnSerialDBL::train(double *a, double *b)
  418. {
  419.     for (int i = 0; i < inputs; i++) {
  420.         a_arr[i] = a[i];
  421.     }
  422.  
  423.     for (int j = 0; j < z_count; j++) {
  424.         z_arr[j] = 0;
  425.     }
  426.  
  427.     calc_feedForward();
  428.  
  429.     for (int i = 0; i < outputs; i++) {
  430.         t_arr[i] = b[i];
  431.     }
  432.     calc_gjl(a_arr, z_arr, t_arr, w_arr, s, sw, L,l, gjl);
  433.  
  434.     //back propogation:
  435.     for (int i = L - 2; i >= 0; i--) {//per sluoksnius
  436.         for (int j = 0; j < l[i]; j++) {//per neuronus
  437.             for (int k = 0; k < l[i + 1] - 1; k++) {//per kito sluoksnio neuronus
  438.                 dw_arr[sw[i] + k + j*(l[i + 1] - 1)] = delta_w(w_gradient(i, j, k, a_arr, s,gjl), dw_arr[sw[i] + k + j*(l[i + 1] - 1)]);
  439.                 w_arr[sw[i] + k + j*(l[i + 1] - 1)] += dw_arr[sw[i] + k + j*(l[i + 1] - 1)];
  440.             }
  441.         }
  442.     }
  443. }
  444.  
  445. void AnnSerialDBL::feedForward(double *a, double *b)
  446. {
  447.     for (int i = 0; i < inputs; i++) {
  448.         a_arr[i] = a[i];
  449.     }
  450.  
  451.     for (int j = 0; j < z_count; j++) {
  452.         z_arr[j] = 0;
  453.     }
  454.  
  455.     calc_feedForward();
  456.  
  457.     double max = 0;
  458.     int index = 0;
  459.  
  460.     for (int i = 0; i<outputs; i++) {
  461.         cout << " a reiksmes: " << a_arr[s[L - 1] + i] << endl;
  462.         if (max < a_arr[s[L - 1] + i]) {
  463.             max = a_arr[s[L - 1] + i];
  464.             index = i;
  465.         }
  466.     }
  467.     for (int i = 0; i < outputs; i++) {
  468.         if (i == index) {
  469.             b[i] = 1;
  470.         }
  471.         else {
  472.             b[i] = 0;
  473.         }
  474.     }
  475. }
  476.  
  477. void AnnSerialDBL::calc_feedForward()
  478. {
  479.     for (int i = 0; i < L - 1; i++) {//per sluoksnius einu+
  480.         for (int j = 0; j < l[i]; j++) { //kiek neuronu sluoksnyje+
  481.             for (int k = 0; k < l[i + 1] - 1; k++) {//per sekancio sluoksnio z+
  482.                 z_arr[s[i + 1] + k] += w_arr[sw[i] + k + j*(l[i + 1] - 1)] * a_arr[s[i] + j];
  483.                 //  cout << "w: "<< w_arr[sw[i] + k + j*(l[i + 1] - 1)] << endl;
  484.                 //  cout << "a: " << a_arr[s[i] + j] << endl;
  485.                 //  cout << "z reiksmes: " << z_arr[s[i+1] + k] << endl;
  486.                 //  cout << endl;
  487.             }
  488.         }
  489.         for (int k = 0; k < l[i + 1] - 1; k++) {//per sekancio sluoksnio z
  490.             a_arr[s[i + 1] + k] = f(z_arr[s[i + 1] + k]);
  491.             //  cout << s[i + 1] + k << " a reiksmes: " << a_arr[s[i + 1] + k] << endl;
  492.         }
  493.     }
  494. }
  495.  
  496. void AnnSerialDBL::destroy()
  497. {
  498.     delete[] l;
  499.     l = NULL;
  500.     delete[] s;
  501.     s = NULL;
  502.  
  503.     delete[] a_arr;
  504.     a_arr = NULL;
  505.     delete[] z_arr;
  506.     z_arr = NULL;
  507.  
  508.     delete[] W;
  509.     W = NULL;
  510.     delete[] sw;
  511.     sw = NULL;
  512.  
  513.     delete[] w_arr;
  514.     w_arr = NULL;
  515.     delete[] dw_arr;
  516.     dw_arr = NULL;
  517.  
  518.     delete[] t_arr;
  519.     t_arr = NULL;
  520.  
  521.     delete[] gjl;
  522.     gjl = NULL;
  523. }
  524.  
  525. void file(string filename, double *a,int dydis) {
  526.     ofstream myfile;
  527.     myfile.open(filename);
  528.     myfile << filename << ";" << endl;
  529.     for (int i = 0; i < dydis; i++) {
  530.         myfile << a[i] << ";" << endl;
  531.     }
  532.     myfile.close();
  533. }
  534.  
  535. void print_all(AnnSerialDBL SerialDBL, int sum, int mult,int i) {
  536.     file("(" + to_string(5 * i + 1) + ")a_arr.csv", SerialDBL.a_arr, sum);
  537.     file("(" + to_string(5 * i + 2) + ")z_arr.csv", SerialDBL.z_arr, sum);
  538.     file("(" + to_string(5 * i + 3) + ")w_arr.csv", SerialDBL.w_arr, mult);
  539.     file("(" + to_string(5 * i + 4) + ")dw_arr.csv", SerialDBL.dw_arr, mult);
  540.     file("(" + to_string(5 * i + 5) + ")g_arr.csv", SerialDBL.gjl, sum);
  541. }
  542.  
  543. void read_W(string filename,double *w_arr) {
  544.     ifstream myReadFile;
  545.     myReadFile.open(filename);
  546.     string a;
  547.     if (myReadFile.is_open()) {
  548.         int i = 0;
  549.         while (!myReadFile.eof()) {
  550.             myReadFile >> a;
  551.             w_arr[i++] = stod(a);
  552.         }
  553.     }
  554.     myReadFile.close();
  555. }
  556.  
  557. void calc_sizes(int &sum, int &mult, Topology top) {
  558.     for (int i = 0; i < topolygy.l.size(); i++) {
  559.         sum += topolygy.l.at(i) + 1;
  560.     }
  561.     for (int i = 0; i < topolygy.l.size() - 1; i++) {
  562.         mult += (topolygy.l.at(i) + 1)*topolygy.l.at(i + 1);
  563.     }
  564. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement