Advertisement
nagoL2015

NeuralNetwork

Jan 20th, 2019
131
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  // Master switch for the periodic training-error log emitted by NeuralNetwork#train.
  const LOG_ON = false;
  // When logging is on, an error line is printed once every LOG_FREQ training calls
  // (NeuralNetwork also uses this as the starting value of its logCount countdown).
  const LOG_FREQ = 1e4;
  3.  
  4. /*********************
  5. |   Neural Network
  6. *********************/
  7.  
  /**
   * Fully connected feed-forward network with a single hidden layer.
   *
   * Both layers use the sigmoid activation. Activations are handled as 1 x n
   * row-vector `Matrix` instances, so a forward pass computes
   *   hidden = sigmoid(inputs . weights0 + bias0)
   *   output = sigmoid(hidden . weights1 + bias1)
   * and `train` performs one backpropagation / gradient-descent step for a
   * single sample.
   */
  class NeuralNetwork {
    /**
     * Creates a network whose weights and biases are filled by
     * `Matrix#randomWeights` (uniform values in [-1, 1)).
     *
     * @param {number} numInputs - Number of input values per sample.
     * @param {number} numHidden - Number of hidden-layer neurons.
     * @param {number} numOutputs - Number of output values per sample.
     */
    constructor(numInputs, numHidden, numOutputs) {
      this._inputs = [];
      this._hidden = [];
      this._numInputs = numInputs;
      this._numHidden = numHidden;
      this._numOutputs = numOutputs;
      this._bias0 = new Matrix(1, this._numHidden);
      this._bias1 = new Matrix(1, this._numOutputs);
      this._weights0 = new Matrix(this._numInputs, this._numHidden);
      this._weights1 = new Matrix(this._numHidden, this._numOutputs);

      // Countdown until the next training-error log line (only used when LOG_ON).
      this._logCount = LOG_FREQ;

      this._bias0.randomWeights();
      this._bias1.randomWeights();
      this._weights0.randomWeights();
      this._weights1.randomWeights();
    }

    // Accessors for the cached activations and the trainable parameters.

    get inputs() {
      return this._inputs;
    }

    set inputs(inputs) {
      this._inputs = inputs;
    }

    get hidden() {
      return this._hidden;
    }

    set hidden(hidden) {
      this._hidden = hidden;
    }

    get bias0() {
      return this._bias0;
    }

    set bias0(bias0) {
      this._bias0 = bias0;
    }

    get bias1() {
      return this._bias1;
    }

    set bias1(bias1) {
      this._bias1 = bias1;
    }

    get weights0() {
      return this._weights0;
    }

    set weights0(weights) {
      this._weights0 = weights;
    }

    get weights1() {
      return this._weights1;
    }

    set weights1(weights) {
      this._weights1 = weights;
    }

    get logCount() {
      return this._logCount;
    }

    set logCount(logCount) {
      this._logCount = logCount;
    }

    /**
     * Throws a descriptive Error when `values.length` differs from `expected`.
     *
     * @param {ArrayLike<number>} values - Array whose length is checked.
     * @param {number} expected - Required length.
     * @param {string} label - Prefix used in the error message.
     * @throws {Error} On a length mismatch.
     */
    static _checkLength(values, expected, label) {
      if (values.length !== expected) {
        throw new Error(`${label} length ${values.length} does not match expected ${expected}`);
      }
    }

    /**
     * Runs a forward pass and caches the intermediate activations
     * (`inputs`, `hidden`) that `train` needs for backpropagation.
     *
     * @param {number[]} inputArray - One sample; its length must equal the
     *   number of rows of `weights0` (numInputs for a freshly built network).
     * @returns {Matrix} 1 x numOutputs matrix of sigmoid outputs.
     * @throws {Error} If `inputArray` has the wrong length.
     */
    feedForward(inputArray) {
      // Checked against the live weight shape (not _numInputs) so that
      // weights replaced through the setter are validated correctly.
      NeuralNetwork._checkLength(inputArray, this.weights0.rows, 'Input');
      this.inputs = Matrix.convertFromArray(inputArray);

      this.hidden = Matrix.dot(this.inputs, this.weights0);
      this.hidden = Matrix.add(this.hidden, this.bias0);
      this.hidden = Matrix.map(this.hidden, (x) => sigmoid(x));

      let outputs = Matrix.dot(this.hidden, this.weights1);
      outputs = Matrix.add(outputs, this.bias1);
      outputs = Matrix.map(outputs, (x) => sigmoid(x));

      return outputs;
    }

    /**
     * Convenience wrapper around `feedForward` that returns the raw output data.
     *
     * @param {number[]} inputArray - Input sample.
     * @returns {number[][]} Nested array of shape [1][numOutputs].
     */
    run(inputArray) {
      return this.feedForward(inputArray).data;
    }

    /**
     * Logs the output errors on the first call and then once every LOG_FREQ
     * calls, counting `logCount` down and resetting it at zero.
     *
     * @param {Matrix} outputErrors - 1 x numOutputs matrix of (target - output).
     */
    _logError(outputErrors) {
      if (this.logCount === LOG_FREQ) {
        // Join every output's error; for a single output this prints the same
        // text as logging data[0][0] alone.
        console.log(`error = ${outputErrors.data[0].join(', ')}`);
      }
      this.logCount--;
      if (this.logCount === 0) {
        this.logCount = LOG_FREQ;
      }
    }

    /**
     * Performs one backpropagation / gradient-descent step on a single sample,
     * updating weights0, weights1, bias0 and bias1 in place.
     *
     * @param {number[]} inputArray - Input sample (length numInputs).
     * @param {number[]} targetArray - Desired outputs (length numOutputs).
     * @param {number} [learningRate=1] - Step size applied to every weight and
     *   bias update. The default of 1 reproduces the unscaled update.
     * @throws {Error} On mismatched lengths or a non-finite learning rate.
     */
    train(inputArray, targetArray, learningRate = 1) {
      if (!Number.isFinite(learningRate)) {
        throw new Error(`learningRate must be a finite number, got ${String(learningRate)}`);
      }
      NeuralNetwork._checkLength(targetArray, this.weights1.cols, 'Target');

      const outputs = this.feedForward(inputArray);
      const targets = Matrix.convertFromArray(targetArray);
      const outputErrors = Matrix.subtract(targets, outputs);

      if (LOG_ON) {
        this._logError(outputErrors);
      }

      // sigmoid(x, true) computes x * (1 - x), i.e. it expects an already
      // activated value, so the cached activations are passed in.
      const outputDerivs = Matrix.map(outputs, (x) => sigmoid(x, true));
      const outputDeltas = Matrix.multiply(outputErrors, outputDerivs);

      // Propagate the error back through weights1 *before* it is updated below.
      const hiddenErrors = Matrix.dot(outputDeltas, Matrix.transpose(this.weights1));
      const hiddenDerivs = Matrix.map(this.hidden, (x) => sigmoid(x, true));
      const hiddenDeltas = Matrix.multiply(hiddenErrors, hiddenDerivs);

      const scale = (m) => Matrix.map(m, (x) => x * learningRate);

      const weights1Step = Matrix.dot(Matrix.transpose(this.hidden), outputDeltas);
      this.weights1 = Matrix.add(this.weights1, scale(weights1Step));
      const weights0Step = Matrix.dot(Matrix.transpose(this.inputs), hiddenDeltas);
      this.weights0 = Matrix.add(this.weights0, scale(weights0Step));

      this.bias1 = Matrix.add(this.bias1, scale(outputDeltas));
      this.bias0 = Matrix.add(this.bias0, scale(hiddenDeltas));
    }
  }
  136.  
  /**
   * Logistic sigmoid activation, 1 / (1 + e^-x).
   *
   * When `deriv` is truthy this returns the derivative written in terms of the
   * sigmoid's own output: pass s = sigmoid(z) and get s * (1 - s). It does not
   * differentiate at a raw pre-activation value.
   *
   * @param {number} x - Pre-activation value, or an activated value when `deriv` is set.
   * @param {boolean} [deriv=false] - Select the derivative form.
   * @returns {number}
   */
  function sigmoid(x, deriv = false) {
    return deriv ? x * (1 - x) : 1 / (1 + Math.exp(-x));
  }
  143.  
  144. /***********************
  145. |   Matrix Functions
  146. ***********************/
  147.  
  /**
   * Minimal dense 2-D matrix of numbers, stored row-major as an array of row arrays.
   *
   * The static operations never modify their operands; they return a new Matrix.
   * `randomWeights` and writes through `data` mutate the matrix in place.
   */
  class Matrix {
    /**
     * @param {number} rows - Number of rows.
     * @param {number} cols - Number of columns.
     * @param {number[][]} [data=[]] - Optional initial contents, stored as-is
     *   (not copied). When omitted, null or empty, the matrix is zero-filled.
     * @throws {Error} If `data` is given but is not `rows` rows of `cols` values each.
     */
    constructor(rows, cols, data = []) {
      this._rows = rows;
      this._cols = cols;

      if (data == null || data.length === 0) {
        this._data = [];
        for (let i = 0; i < this._rows; i++) {
          this._data[i] = [];
          for (let j = 0; j < this._cols; j++) {
            this._data[i][j] = 0;
          }
        }
        return;
      }

      // Validate every row, not just the first: a ragged row would otherwise
      // surface later as `undefined` arithmetic (NaN) inside dot/add/etc.
      const ragged = data.some((row) => row == null || row.length !== cols);
      if (data.length !== rows || ragged) {
        throw new Error("Incorrect data dimensions!");
      }
      this._data = data;
    }

    get rows() {
      return this._rows;
    }

    get cols() {
      return this._cols;
    }

    /** The underlying row arrays (live reference, not a copy). */
    get data() {
      return this._data;
    }

    /**
     * Combines two equally sized matrices element by element.
     *
     * @param {Matrix} m0
     * @param {Matrix} m1
     * @param {(a: number, b: number) => number} combine - Applied to each pair of entries.
     * @returns {Matrix} New matrix of the same shape.
     * @throws {Error} If the shapes differ.
     */
    static _zipWith(m0, m1, combine) {
      Matrix.checkDimensions(m0, m1);
      const m = new Matrix(m0.rows, m0.cols);
      for (let i = 0; i < m.rows; i++) {
        for (let j = 0; j < m.cols; j++) {
          m.data[i][j] = combine(m0.data[i][j], m1.data[i][j]);
        }
      }
      return m;
    }

    /** Element-wise sum m0 + m1. @throws {Error} If the shapes differ. */
    static add(m0, m1) {
      return Matrix._zipWith(m0, m1, (a, b) => a + b);
    }

    /** Throws unless m0 and m1 have identical row and column counts. */
    static checkDimensions(m0, m1) {
      if (m0.rows !== m1.rows || m0.cols !== m1.cols) {
        throw new Error('Matrices are different dimensions!');
      }
    }

    /**
     * Wraps a flat array as a 1 x n row vector. The values are copied so that
     * in-place mutation of the matrix (e.g. `randomWeights`) cannot alter the
     * caller's array, and later changes to that array cannot alter the matrix.
     *
     * @param {ArrayLike<number>} arr
     * @returns {Matrix}
     */
    static convertFromArray(arr) {
      return new Matrix(1, arr.length, [Array.from(arr)]);
    }

    /**
     * Matrix product m0 . m1.
     *
     * @returns {Matrix} New m0.rows x m1.cols matrix.
     * @throws {Error} If m0.cols !== m1.rows.
     */
    static dot(m0, m1) {
      if (m0.cols !== m1.rows) {
        throw new Error("Matrices are not \"dot\" compatible!");
      }
      const m = new Matrix(m0.rows, m1.cols);
      for (let i = 0; i < m.rows; i++) {
        for (let j = 0; j < m.cols; j++) {
          let sum = 0;
          for (let k = 0; k < m0.cols; k++) {
            sum += m0.data[i][k] * m1.data[k][j];
          }
          m.data[i][j] = sum;
        }
      }
      return m;
    }

    /** Returns a new matrix with `mFunction` applied to every entry of m0. */
    static map(m0, mFunction) {
      const m = new Matrix(m0.rows, m0.cols);
      for (let i = 0; i < m.rows; i++) {
        for (let j = 0; j < m.cols; j++) {
          m.data[i][j] = mFunction(m0.data[i][j]);
        }
      }
      return m;
    }

    /** Element-wise (Hadamard) product. @throws {Error} If the shapes differ. */
    static multiply(m0, m1) {
      return Matrix._zipWith(m0, m1, (a, b) => a * b);
    }

    /** Element-wise difference m0 - m1. @throws {Error} If the shapes differ. */
    static subtract(m0, m1) {
      return Matrix._zipWith(m0, m1, (a, b) => a - b);
    }

    /** Returns a new cols x rows matrix with rows and columns swapped. */
    static transpose(m0) {
      const m = new Matrix(m0.cols, m0.rows);
      for (let i = 0; i < m0.rows; i++) {
        for (let j = 0; j < m0.cols; j++) {
          m.data[j][i] = m0.data[i][j];
        }
      }
      return m;
    }

    /** Overwrites every entry in place with a value drawn uniformly from [-1, 1). */
    randomWeights() {
      for (let i = 0; i < this.rows; i++) {
        for (let j = 0; j < this.cols; j++) {
          this.data[i][j] = Math.random() * 2 - 1;
        }
      }
    }
  }
  269.  
  // CommonJS export of this file's two classes: Matrix and NeuralNetwork.
  270. module.exports = { Matrix, NeuralNetwork };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement