Advertisement
riabcis

Untitled

Apr 6th, 2018
324
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 12.92 KB | None | 0 0
  1. // Ann.cpp : Defines the entry point for the console application.
  2. //
  3.  
  4. #include "stdafx.h"
  5. # include <cmath>
  6. #include <math.h>
  7. #include <vector>
  8. #include <iostream>
  9. #include <iomanip>
  10. #include <fstream>
  11. #include <string>
  12. #include <array>
  13. using namespace std;
  14. //Generating random number: either 0 or 1, uniform distribution, for XOR operation. Can remove later if using data from files.
  15. int randint();
  16. double f(double x);
  17. double f_deriv(double x);
  18. double gL(double a, double z, double t);
  19. double w_gradient(int layer_id, int w_i, int w_j, double *a_arr, int *s, double *gll);
  20. double delta_w(double grad, double dw);
  21. void calc_gjl(double *a_arr, double *z_arr, double *t_arr, double *w_arr, int *s, int *sw, int L, int *l, double *gll);
  22. const double ETA = 0.1;
  23. const double ALPHA = 0;
  24. void file(string filename, double *a, int dydis);
  25.  
  26. struct Topology
  27. {
  28.     std::vector<int> l;//kiekiai sluoksnyje
  29. } topolygy;
  30.  
  31. struct Sample
  32. {
  33.     double * input;
  34.     double * output;
  35.     string ToString() {
  36.         string str;
  37.         str = "input: " + to_string(input[0]) + " " + to_string(input[1]) + " output: " + to_string(output[0]) + " " + to_string(output[1]) + "\n";
  38.         return str;
  39.     }
  40. };
  41. class Data
  42. {
  43. public:
  44.     int getNumberOfInputs() { return inputs; }
  45.     int getNumberOfOutputs() { return outputs; }
  46.  
  47.     double * getInput(int index);
  48.  
  49.     double * getOutput(int index);
  50.  
  51.     int getNumberOfSamples() { return samples; }
  52.  
  53.     void addSample(Sample sample);
  54.  
  55.     void setSizes(int input_size, int output_size);
  56.  
  57. protected:
  58.     std::vector<Sample> data;
  59.     int inputs;
  60.     int outputs;
  61.     int samples = 0;
  62. };
  63.  
  64. class XOR : public Data
  65. {
  66. public:
  67.     void generate(int n);
  68.  
  69.     XOR()
  70.     {
  71.         inputs = 2;
  72.         outputs = 2;
  73.         samples = 0;
  74.     }
  75.     void printInputs(int index);
  76.  
  77.     void printOutputs(int index);
  78. };
  79.  
  80. class AnnBase {
  81. public:
  82.     virtual void prepare(Topology top) = 0;
  83.     virtual void init(Topology top, double w_arr_1[]) = 0;
  84.     virtual void train(double *a, double *b) = 0;
  85.     virtual void feedForward(double *a, double *b) = 0;
  86.     virtual void destroy() = 0;
  87. private:
  88.     virtual void calc_feedForward() = 0;
  89. };
  90.  
  91. class AnnSerialDBL : public AnnBase {
  92. public:
  93.     void prepare(Topology top);
  94.     void init(Topology top, double w_arr_1[]);
  95.     void train(double *a, double *b);
  96.     void feedForward(double *a, double *b);
  97.     void destroy();
  98. private:
  99.     void calc_feedForward();
  100. public:
  101.     int z_count;//temp var to keep the length of z, so z could be reset for calcs.
  102.     int inputs;
  103.     int outputs;
  104.     int L;
  105.     int * l;
  106.     int * s;
  107.     double * a_arr;
  108.     double * z_arr;
  109.     int * W;
  110.     int * sw;
  111.     double * w_arr;
  112.     double * dw_arr;
  113.     double * t_arr;
  114.     double * gjl;
  115.  
  116.     //del grafiku
  117.     double * E_error = new double[5000];
  118.     double * w_diffrence = new double[5000];
  119.     int kiek_yra = 0;
  120. };
  121. void print_all(AnnSerialDBL & SerialDBL, int sum, int mult, int i);
  122. void read_W(string filename, double *w_arr);
  123. void calc_sizes(int &sum, int &mult, Topology top);
  124. int main()
  125. {
  126.     //15+24+10  49
  127.     topolygy.l.push_back(2);
  128.     topolygy.l.push_back(5);
  129.     topolygy.l.push_back(4);
  130.     topolygy.l.push_back(2);
  131.     AnnSerialDBL  SerialDBL;
  132.  
  133.    
  134.     int sum = 0;
  135.     int mult = 0;
  136.     calc_sizes(sum, mult, topolygy);
  137.  
  138.     SerialDBL.prepare(topolygy);
  139.  
  140.     double * a1=new double[mult];
  141.     //read_W("W_ARR.txt", a1);
  142.     SerialDBL.init(topolygy, NULL);
  143.     delete[] a1;
  144.     a1 = NULL;
  145.  
  146. //  print_all(SerialDBL, sum, mult, 0);
  147.  
  148.     XOR xo;
  149.     xo.generate(30000);
  150.     SerialDBL.train(xo.getInput(0), xo.getOutput(0));
  151.  
  152. //  print_all(SerialDBL, sum, mult, 1);
  153.  
  154.     for (int i = 1; i < xo.getNumberOfSamples(); i++) {
  155.         SerialDBL.train(xo.getInput(i), xo.getOutput(i));
  156.     }
  157.  
  158.     //Checking results(all combinations 0 and 1)
  159.     for (double i = 0; i < 2; i++) {
  160.         for (double j = 0; j < 2; j++) {
  161.             double input[] = { i ,j };
  162.             double output[] = { 0,0 };
  163.             SerialDBL.feedForward(input, output);
  164.  
  165.             cout << endl << "input : " << input[0] << "   " << input[1] << endl;
  166.             cout << endl << "output: " << output[0] << "   " << output[1] << endl << endl;
  167.             cout << "---------------------------------------------------" << endl;
  168.         }
  169.     }
  170.  
  171.     //Checking results(all combinations 0 and 1)
  172.     for (double i = 0; i < 0; i++) {
  173.         double input[] = { randint()*1.0, randint()*1.0 };
  174.         double output[] = { 0,0 };
  175.         SerialDBL.feedForward(input, output);
  176.  
  177.         cout << endl << "input : " << input[0] << "   " << input[1] << endl;
  178.         cout << endl << "output: " << output[0] << "   " << output[1] << endl << endl;
  179.         cout << "---------------------------------------------------" << endl;
  180.     }
  181.  
  182. //  print_all(SerialDBL, sum, mult, 2);
  183.  
  184. //  file("(20)E_error.csv", SerialDBL.E_error, 5000);
  185. //  file("(21)w_diffrence.csv", SerialDBL.w_diffrence, 5000);
  186.     SerialDBL.destroy();
  187.  
  188.     int a;
  189.     cin >> a;
  190.  
  191.     return 0;
  192. }
  193. //returns random int, either 0 or 1
  194. int randint() {
  195.     double r = ((double)rand() / (RAND_MAX));
  196.     int a = 0;
  197.     if (r > 0.5) {
  198.         a = 1;
  199.     }
  200.     else
  201.     {
  202.         a = 0;
  203.     }
  204.     return a;
  205. }
  206.  
  207. double f(double x) {
  208.     double y = 1 + exp(-x);
  209.     //temp*********************
  210.     if (y == 0) {
  211.         cout << "Error 1";
  212.     }
  213.     if ((y - 1) == 0) {
  214.         //cout << "Error 2";
  215.     }
  216.     //temp**********************
  217.     return 1 / y;
  218. }
  219.  
  220. double f_deriv(double x) {
  221.     //Temp**********
  222.     double y = pow((1 + exp(-x)), 2);
  223.     double z = exp(-x);
  224.     if (y == 0) {
  225.         cout << "Error 3";
  226.     }
  227.     if (z == 0) {
  228.     //  cout << "Error 4";
  229.     }
  230.     //temp**********************
  231.     return exp(-x) / pow((1 + exp(-x)), 2);
  232. }
  233.  
  234. double gL(double a, double z, double t) {
  235.     double w = f_deriv(z) * (a - t);
  236.     //cout << "z: " << z << " a: " << a << " t: " << t << endl;
  237.     return w;
  238. }
  239.  
  240. double w_gradient(int layer_id, int w_i, int w_j, double *a_arr, int *s, double *gll) {
  241.     return a_arr[s[layer_id] + w_i] * gll[s[layer_id+1] + w_j];
  242. }
  243.  
  244. double delta_w(double grad, double dw) {
  245.     return (-ETA)*grad + ALPHA*dw;
  246. }
  247.  
  248. void calc_gjl(double *a_arr, double *z_arr, double *t_arr, double *w_arr, int *s, int *sw, int L, int *l, double *gll)
  249. {
  250.     for (int i = L-1; i >= 0; i--) {
  251.         for (int j = 0; j < l[i]; j++) {
  252.             if (L - 1 == i) {
  253.                 gll[s[i] + j] = gL(a_arr[s[i] + j], z_arr[s[i] + j], t_arr[j]);
  254.             }
  255.             else {
  256.                 gll[s[i] + j] = f_deriv(z_arr[s[i] + j]);
  257.                 double sum = 0;
  258.                 for (int k = 0; k < l[i+1]-1; k++) {
  259.                     sum += w_arr[sw[i] + j*(l[i + 1] - 1) + k] * gll[s[i + 1] + k];
  260.                 }
  261.                 gll[s[i] + j] *= sum;
  262.             }
  263.         }
  264.     }
  265. }
  266.  
  267. //*********
  268. double * Data::getInput(int index)
  269. {
  270.     return data[index].input;
  271. }
  272.  
  273. double * Data::getOutput(int index)
  274. {
  275.     return data[index].output;
  276. }
  277.  
  278. void Data::addSample(Sample sample)
  279. {
  280.     data.push_back(sample);
  281.     samples++;
  282.     //cout << sample.ToString();
  283. }
  284.  
  285. void Data::setSizes(int input_size, int output_size)
  286. {
  287.     inputs = input_size;
  288.     outputs = output_size;
  289. }
  290.  
  291. //****************
  292. void XOR::generate(int n)
  293. {
  294.     for (int i = 0; i < n / 4; i++)
  295.     {
  296.         for (double j = 0; j < 2; j++) {
  297.             for (double k = 0; k < 2; k++) {
  298.                 double * input = new double[2];
  299.                 input[0]= j;
  300.                 input[1] = k;
  301.                 double * output = new double[2];
  302.                 output[0] = j == k;
  303.                 output[1] = j != k;
  304.                 addSample({input,output});
  305.             }
  306.         }
  307.  
  308.     }
  309. }
  310.  
  311. void XOR::printInputs(int index)
  312. {
  313.     cout << index << " index inputs: " << endl;
  314.     for (int i = 0; i < inputs; i++)
  315.     {
  316.         cout << getInput(index)[i] << " ";
  317.     }
  318.     cout << endl;
  319. }
  320.  
  321. void XOR::printOutputs(int index)
  322. {
  323.     cout << index << " index outputs: " << endl;
  324.     for (int i = 0; i < outputs; i++)
  325.     {
  326.         cout << fixed << setprecision(2) << data[index].output[i] << " ";
  327.     }
  328.     cout << endl;
  329. }
  330.  
  331. //****************
  332. void AnnSerialDBL::prepare(Topology top)
  333. {
  334.     inputs = top.l.at(0);
  335.     outputs = top.l.at(top.l.size() - 1);
  336.  
  337.     l = new int[top.l.size()];
  338.     s = new int[top.l.size()];
  339.  
  340.     int sum = 0;
  341.     int mult = 0;
  342.     for (int i = 0; i < top.l.size(); i++) {
  343.         sum += top.l.at(i) + 1;
  344.     }
  345.     z_count = sum;
  346.     for (int i = 0; i < top.l.size() - 1; i++) {
  347.         mult += (top.l.at(i) + 1)*top.l.at(i + 1);
  348.     }
  349.     a_arr = new double[sum];
  350.     z_arr = new double[sum];
  351.  
  352.     W = new int[top.l.size()];
  353.     sw = new int[top.l.size()];
  354.  
  355.     w_arr = new double[mult];
  356.     dw_arr = new double[mult];
  357.  
  358.     t_arr = new double[top.l.at(top.l.size() - 1)];
  359.  
  360.     gjl = new double[sum];
  361. }
  362.  
  363. void AnnSerialDBL::init(Topology top, double w_arr_1[] = NULL)
  364. {
  365.     L = top.l.size();
  366.     //Neuronu kiekiai sluoksnyje
  367.     for (int i = 0; i < top.l.size(); i++) {
  368.         l[i] = top.l.at(i) + 1;
  369.     }
  370.  
  371.     //Sluoksniu pradzios indeksai
  372.     for (int i = 0; i < top.l.size(); i++) {
  373.         s[i] = 0;
  374.         for (int j = i; j > 0; j--) {
  375.             s[i] += l[j - 1];
  376.         }
  377.     }
  378.  
  379.     //Bias neuronai
  380.     for (int i = 0; i < top.l.size() - 1; i++) {
  381.         a_arr[s[i + 1] - 1] = 1;
  382.     }
  383.  
  384.     //Svoriu kiekiai l-ame sluoksnyje
  385.     for (int i = 0; i < top.l.size() - 1; i++) {
  386.         W[i] = l[i] * (l[i + 1] - 1);
  387.         //cout << "Svoriu sk: " << W[i] << " Pradzios index: ";
  388.         sw[i] = 0;
  389.         if (i != 0) {
  390.             for (int j = 0; j < i; j++) {
  391.                 sw[i] += W[j];
  392.             }
  393.         }
  394.         if (w_arr_1 == NULL) {
  395.             for (int j = 0; j < W[i]; j++) {
  396.                 w_arr[sw[i] + j] = (double)rand() / double(RAND_MAX);
  397.                 //cout << w_arr[sw[i] + j]<< endl;
  398.                 dw_arr[sw[i] + j] = 0;
  399.             }
  400.         }
  401.         else {
  402.             //w_arr = w_arr_1; //ar reikia pokycius issisaugoti irgi?
  403.             for (int j = 0; j < W[i]; j++) {
  404.                 w_arr[sw[i] + j] = w_arr_1[sw[i] + j];
  405.                 //cout << w_arr[sw[i] + j]<< endl;
  406.                 dw_arr[sw[i] + j] = 0;
  407.             }
  408.         }
  409.  
  410.         //cout << sw[i] << " " << endl;
  411.     }
  412. }
  413.  
  414. void AnnSerialDBL::train(double *a, double *b)
  415. {
  416.     for (int i = 0; i < inputs; i++) {
  417.         a_arr[i] = a[i];
  418.     }
  419.  
  420.     for (int j = 0; j < z_count; j++) {
  421.         z_arr[j] = 0;
  422.     }
  423.  
  424.     calc_feedForward();
  425.  
  426.     for (int i = 0; i < outputs; i++) {
  427.         t_arr[i] = b[i];
  428.     }
  429.     calc_gjl(a_arr, z_arr, t_arr, w_arr, s, sw, L,l, gjl);
  430.  
  431.     //back propogation:
  432.     for (int i = L - 1; i >= 0; i--) {//per sluoksnius
  433.         for (int j = 0; j < l[i]; j++) {//per neuronus
  434.             for (int k = 0; k < l[i + 1] - 1; k++) {//per kito sluoksnio neuronus
  435.                 dw_arr[sw[i] + k + j*(l[i + 1] - 1)] = delta_w(w_gradient(i, j, k, a_arr, s,gjl), dw_arr[sw[i] + k + j*(l[i + 1] - 1)]);
  436.                 w_arr[sw[i] + k + j*(l[i + 1] - 1)] += dw_arr[sw[i] + k + j*(l[i + 1] - 1)];
  437.             }
  438.         }
  439.     }
  440.  
  441.     //del grafiku
  442.     if ((kiek_yra % 10 == 0)&&(kiek_yra<5000)) {
  443.         E_error[kiek_yra/10] = (a[0] == a[1]) ? abs(1 - a_arr[s[L - 1]]) : abs(1 - a_arr[s[L - 1] + 1]);
  444.         w_diffrence[kiek_yra/10] = abs(a_arr[s[L - 1]] - a_arr[s[L - 1] + 1]);
  445.     }
  446.     kiek_yra++;
  447.  
  448. }
  449.  
  450. void AnnSerialDBL::feedForward(double *a, double *b)
  451. {
  452.     for (int i = 0; i < inputs; i++) {
  453.         a_arr[i] = a[i];
  454.     }
  455.  
  456.     for (int j = 0; j < z_count; j++) {
  457.         z_arr[j] = 0;
  458.     }
  459.  
  460.     calc_feedForward();
  461.  
  462.     double max = 0;
  463.     int index = 0;
  464.  
  465.     for (int i = 0; i<outputs; i++) {
  466.         cout << " a reiksmes: " << a_arr[s[L - 1] + i] << endl;
  467.         if (max < a_arr[s[L - 1] + i]) {
  468.             max = a_arr[s[L - 1] + i];
  469.             index = i;
  470.         }
  471.     }
  472.     for (int i = 0; i < outputs; i++) {
  473.         if (i == index) {
  474.             b[i] = 1;
  475.         }
  476.         else {
  477.             b[i] = 0;
  478.         }
  479.     }
  480. }
  481.  
  482. void AnnSerialDBL::calc_feedForward()
  483. {
  484.     for (int i = 0; i <= L - 1; i++) {//per sluoksnius einu+
  485.         for (int j = 0; j < l[i]; j++) { //kiek neuronu sluoksnyje+
  486.             for (int k = 0; k < l[i + 1] - 1; k++) {//per sekancio sluoksnio z+
  487.                 z_arr[s[i + 1] + k] += w_arr[sw[i] + k + j*(l[i + 1] - 1)] * a_arr[s[i] + j];
  488.                 //  cout << "w: "<< w_arr[sw[i] + k + j*(l[i + 1] - 1)] << endl;
  489.                 //  cout << "a: " << a_arr[s[i] + j] << endl;
  490.                 //  cout << "z reiksmes: " << z_arr[s[i+1] + k] << endl;
  491.                 //  cout << endl;
  492.             }
  493.         }
  494.         for (int k = 0; k < l[i + 1] - 1; k++) {//per sekancio sluoksnio z
  495.             a_arr[s[i + 1] + k] = f(z_arr[s[i + 1] + k]);
  496.             //  cout << s[i + 1] + k << " a reiksmes: " << a_arr[s[i + 1] + k] << endl;
  497.         }
  498.     }
  499. }
  500.  
  501. void AnnSerialDBL::destroy()
  502. {
  503.     delete[] l;
  504.     l = NULL;
  505.     delete[] s;
  506.     s = NULL;
  507.  
  508.     delete[] a_arr;
  509.     a_arr = NULL;
  510.     delete[] z_arr;
  511.     z_arr = NULL;
  512.  
  513.     delete[] W;
  514.     W = NULL;
  515.     delete[] sw;
  516.     sw = NULL;
  517.  
  518.     delete[] w_arr;
  519.     w_arr = NULL;
  520.     delete[] dw_arr;
  521.     dw_arr = NULL;
  522.  
  523.     delete[] t_arr;
  524.     t_arr = NULL;
  525.  
  526.     delete[] gjl;
  527.     gjl = NULL;
  528. }
  529.  
  530. void file(string filename, double *a,int dydis) {
  531.     ofstream myfile;
  532.     myfile.open(filename);
  533.     myfile << filename << ";" << endl;
  534.     for (int i = 0; i < dydis; i++) {
  535.         myfile << a[i] << ";" << endl;
  536.     }
  537.     myfile.close();
  538. }
  539.  
  540. void print_all(AnnSerialDBL  &SerialDBL, int sum, int mult,int i) {
  541.     file("(" + to_string(5 * i + 1) + ")a_arr.csv", SerialDBL.a_arr, sum);
  542.     file("(" + to_string(5 * i + 2) + ")z_arr.csv", SerialDBL.z_arr, sum);
  543.     file("(" + to_string(5 * i + 3) + ")w_arr.csv", SerialDBL.w_arr, mult);
  544.     file("(" + to_string(5 * i + 4) + ")dw_arr.csv",SerialDBL.dw_arr, mult);
  545.     file("(" + to_string(5 * i + 5) + ")g_arr.csv", SerialDBL.gjl, sum);
  546. }
  547.  
  548. void read_W(string filename,double *w_arr) {
  549.     ifstream myReadFile;
  550.     myReadFile.open(filename);
  551.     string a;
  552.     if (myReadFile.is_open()) {
  553.         int i = 0;
  554.         while (!myReadFile.eof()) {
  555.             myReadFile >> a;
  556.             w_arr[i++] = stod(a);
  557.         }
  558.     }
  559.     myReadFile.close();
  560. }
  561.  
  562. void calc_sizes(int &sum, int &mult, Topology top) {
  563.     for (int i = 0; i < topolygy.l.size(); i++) {
  564.         sum += topolygy.l.at(i) + 1;
  565.     }
  566.     for (int i = 0; i < topolygy.l.size() - 1; i++) {
  567.         mult += (topolygy.l.at(i) + 1)*topolygy.l.at(i + 1);
  568.     }
  569. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement