Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.example;
- import javax.swing.*;
- import java.awt.*;
- import java.awt.event.ActionEvent;
- import java.awt.event.ActionListener;
- public class Main extends JPanel implements ActionListener {
- // Константы
- private static final int N = 100;
- private static final double TAU = 0.05;
- private static final double TAU_STEP = 0.005;
- private static final double D_H = 0.01;
- private static final double ALPHA = 0.005;
- private static final int R = 10;
- private static final double START_X = -7.0; // Начальное значение c_x
- private static final double END_X = 1.0; // Конечное значение c_x
- private static final double A0 = 0.45;
- private static final double A1 = 3.06;
- private static final double A2 = 8.08;
- private static final double A3 = 13.06;
- private static final double A4 = 6.03;
- private static final double A5 = -3.8;
- // Переменные нейросети
- private double c_x;
- private double[] u_t;
- private double[] u;
- private double[] u_n;
- private double[] g;
- private int n_t;
- private int step;
- private int err;
- private double learningPercentage;
- // Графические параметры
- private double graphW, graphH, graphDx, graphBx, graphRx;
- // Таймер для анимации
- private Timer timer;
- public Main() {
- // Инициализация переменных
- c_x = START_X; // Начинаем с -7
- u_t = new double[N];
- u = new double[N];
- u_n = new double[N];
- g = new double[R];
- for (int i = 0; i < N; i++) {
- u[i] = 0.0;
- u_t[i] = 0.0;
- u_n[i] = 0.0;
- }
- for (int i = 0; i < R; i++) {
- g[i] = 0.0;
- }
- n_t = N / 2 + 1; // Обучаемая точка в середине
- step = 1;
- err = 0;
- learningPercentage = 0.0;
- // Запуск таймера (обновление каждые 10 мс)
- timer = new Timer(10, this);
- timer.start();
- }
- // Функция sign(x)
- private double sign(double x) {
- if (x > 0) return 1.0;
- else if (x < 0) return -1.0;
- return 0.0;
- }
- // Функция f(x)
- private double f(double x) {
- return A0 * Math.pow(x, 5) + A1 * Math.pow(x, 4) + A2 * Math.pow(x, 3)
- + A3 * Math.pow(x, 2) + A4 * x + A5;
- }
- // Функция обучения
- private void teach(double[] u, double[] g, int n_t, double delta) {
- double[] delta_g = new double[R];
- for (int r = 0; r < R; r++) {
- double x_r = u[n_t - 1 - r];
- delta_g[r] = ALPHA * sign(x_r) * sign(delta);
- }
- for (int r = 0; r < R; r++) {
- g[r] += delta_g[r];
- }
- err++;
- }
- // Вычисление процента обучения
- private void calculateLearningPercentage() {
- double totalError = 0.0;
- double minVal = u[0], maxVal = u[0];
- for (int i = 0; i < N; i++) {
- totalError += Math.abs(u_n[i] - u_t[i]);
- if (u[i] < minVal) minVal = u[i];
- if (u[i] > maxVal) maxVal = u[i];
- }
- double avgError = totalError / N;
- double range = maxVal - minVal;
- if (range < 1e-9) range = 1.0;
- learningPercentage = 100 * (1 - avgError / range);
- if (learningPercentage < 0) learningPercentage = 0;
- if (learningPercentage > 100) learningPercentage = 100;
- }
- // Выполнение одного шага
- private void newStep() {
- // Новое положение F
- for (int i = 0; i < N; i++) {
- u_n[i] = f(c_x + i * TAU);
- }
- c_x += TAU_STEP;
- // Сброс c_x при выходе за пределы
- if (c_x > END_X) {
- c_x = START_X;
- }
- // Прогноз u_t
- for (int i = 0; i < N; i++) {
- u_t[i] = 0.0;
- }
- for (int i = R; i < N; i++) {
- for (int r = 0; r < R; r++) {
- u_t[i] += g[r] * u[i - 1 - r];
- }
- }
- for (int i = 0; i < R; i++) {
- u_t[i] = u[i];
- }
- // Обучение
- double d_u = u_n[n_t] - u_t[n_t];
- if (d_u < -D_H || d_u > D_H) {
- teach(u, g, n_t, d_u);
- }
- step++;
- }
- // Обновление переменных
- private void updateVars() {
- newStep();
- for (int i = 0; i < N; i++) {
- u[i] = u_n[i];
- }
- calculateLearningPercentage();
- }
- @Override
- protected void paintComponent(Graphics gCanvas) {
- super.paintComponent(gCanvas);
- Graphics2D g2 = (Graphics2D) gCanvas;
- int panelW = getWidth();
- int panelH = getHeight();
- // Инициализация графических параметров
- graphW = panelW;
- graphH = panelH;
- graphDx = graphW / N;
- graphBx = (graphW - graphDx * (N - 1)) / 2.0;
- graphRx = graphBx + (N - 1) * graphDx;
- // Масштабирование
- double minVal = u[0], maxVal = u[0];
- for (int i = 1; i < N; i++) {
- if (u[i] < minVal) minVal = u[i];
- if (u[i] > maxVal) maxVal = u[i];
- }
- double mid = (maxVal + minVal) / 2.0;
- double diff = maxVal - minVal;
- minVal = mid - diff * 1.5;
- maxVal = mid + diff * 1.5;
- if (Math.abs(maxVal - minVal) < 1e-9) {
- minVal -= 1.0;
- maxVal += 1.0;
- }
- double localMtt = graphH / (maxVal - minVal);
- double y_c = localMtt * maxVal;
- // Рисуем оси
- g2.setColor(Color.BLACK);
- g2.drawLine((int)graphBx, 0, (int)graphBx, (int)graphH);
- g2.drawLine((int)graphRx, 0, (int)graphRx, (int)graphH);
- g2.drawLine((int)graphBx, 0, (int)graphRx, 0);
- g2.drawLine((int)graphBx, (int)graphH, (int)graphRx, (int)graphH);
- g2.drawLine((int)graphBx, (int)y_c, (int)graphRx, (int)y_c);
- for (int i = 1; i < N - 1; i++) {
- int x = (int)(graphBx + i * graphDx);
- g2.drawLine(x, (int)(y_c - 3), x, (int)(y_c + 3));
- }
- int x_nt = (int)(graphBx + (n_t - 1) * graphDx);
- g2.drawLine(x_nt, 0, x_nt, (int)graphH);
- // Рисуем кривую u (зелёная)
- g2.setColor(Color.GREEN.darker());
- int oldX = (int)graphBx;
- int oldY = (int)(y_c - localMtt * u[0]);
- for (int i = 1; i < N; i++) {
- int x = (int)(graphBx + i * graphDx);
- int y = (int)(y_c - localMtt * u[i]);
- g2.drawLine(oldX, oldY, x, y);
- oldX = x;
- oldY = y;
- }
- // Рисуем прогноз u_t (красная)
- g2.setColor(Color.RED);
- oldX = (int)graphBx;
- oldY = (int)(y_c - localMtt * u_t[0]);
- for (int i = 1; i < N; i++) {
- int x = (int)(graphBx + i * graphDx);
- int y = (int)(y_c - localMtt * u_t[i]);
- g2.drawLine(oldX, oldY, x, y);
- oldX = x;
- oldY = y;
- }
- // Рисуем текст
- g2.setColor(Color.BLACK);
- int textX = 20;
- int textY = 20;
- g2.drawString("Step = " + step, textX, textY); textY += 16;
- g2.drawString(String.format("dH = %.2f", D_H), textX, textY); textY += 16;
- g2.drawString(String.format("dG = %.3f", ALPHA), textX, textY); textY += 16;
- for (int i = 0; i < R; i++) {
- g2.drawString(String.format("G[%d] = %.3f", i, g[i]), textX, textY);
- textY += 16;
- }
- g2.drawString(String.format("Learning: %.2f%%", learningPercentage), textX, textY);
- textY += 16;
- g2.drawString(String.format("c_x = %.3f", c_x), textX, textY); // Добавим c_x для отладки
- }
- @Override
- public Dimension getPreferredSize() {
- return new Dimension(800, 600);
- }
- @Override
- public void actionPerformed(ActionEvent e) {
- updateVars();
- repaint();
- }
- public static void main(String[] args) {
- SwingUtilities.invokeLater(() -> {
- JFrame frame = new JFrame("Прогнозирование функции (нейросеть)");
- frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
- Main panel = new Main();
- frame.setContentPane(panel);
- frame.pack();
- frame.setLocationRelativeTo(null);
- frame.setVisible(true);
- });
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement