Advertisement
mirosh111000

Використання нейронної мережі Хопфілда для відновлення зображень (pr8)

Nov 3rd, 2024 (edited)
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.01 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import datetime as dt
  4. import os
  5. from IPython.display import clear_output
  6. import pandas as pd
  7.  
  8. def generate_symbol(symbol, height, width):
  9.    
  10.     pattern = np.ones((height, width)) * -1
  11.  
  12.     if symbol == 'П':
  13.         pattern[:, 1] = 1
  14.         pattern[:, -2] = 1
  15.         pattern[0, 1:-2] = 1
  16.     elif symbol == 'Т':
  17.         pattern[:, int(height/2)] = 1
  18.         pattern[0, :] = 1
  19.     elif symbol == 'К':
  20.         pattern[:, 0] = 1
  21.         pattern[int(height/2), :int(width/2)] = 1
  22.         for _, k in enumerate(range(int(width/2), width-1)):
  23.             num = _ + 1
  24.             for i in [-num, num]:
  25.                 if k > int(height/2)+1:
  26.                     pattern[int(height/2)+i, k+1] = 1
  27.                 elif k == int(height/2)+1:
  28.                     pattern[int(height/2)+i, k:k+2] = 1
  29.                 else:
  30.                     pattern[int(height/2)+i, k] = 1
  31.     else:
  32.         return None
  33.        
  34.    
  35.     return pattern.reshape(M)
  36.  
  37. def visualize_comparison(original, noisy, recalled, title=''):
  38.         fig, axes = plt.subplots(1, 3, figsize=(9, 3))
  39.         for ax, img, title in zip(axes, [original, noisy, recalled], ["Оригінал", title, "Відновлений"]):
  40.             ax.imshow(img.reshape(height, width), cmap="gray_r")
  41.             ax.set_title(title)
  42.             ax.axis("off")
  43.         plt.show()
  44.  
  45. def visualize_symbols(symbols):
  46.  
  47.     fig, axes = plt.subplots(1, len(symbols), figsize=(7, 5))
  48.     for ax, (symbol, pattern) in zip(axes, symbols.items()):
  49.         ax.imshow(pattern.reshape(height, width), cmap="gray_r")
  50.         ax.set_title(symbol)
  51.         ax.axis("off")
  52.     plt.show()
  53.  
  54.  
  55. class Hopfield():
  56.     def __init__(self, M, max_iter=100):
  57.         self.M = M
  58.         self.max_iter = max_iter
  59.         self.weights = None
  60.  
  61.     def fit(self, symbols):
  62.         W = np.zeros((self.M, self.M))
  63.         for symbol_vector in symbols.values():
  64.             W += np.outer(symbol_vector, symbol_vector)
  65.         np.fill_diagonal(W, 0)
  66.         self.weights =  W / self.M
  67.  
  68.     def hard_threshold_activation(self, x):
  69.         return np.where(x > 0, 1, -1)
  70.  
  71.     def predict(self, input_vector):
  72.         y = np.copy(input_vector)
  73.         for _ in range(self.max_iter):
  74.             y_new = self.hard_threshold_activation(self.weights @ y)
  75.             if np.array_equal(y, y_new):
  76.                 break
  77.             y = y_new
  78.         return y
  79.  
  80.  
  81. def blackout_noise(image, blackout_prob=0.2):
  82.     noisy_image = np.copy(image)
  83.     black_pixels = np.where(image == 1)
  84.     num_black_pixels = len(black_pixels[0])
  85.    
  86.     for i in range(num_black_pixels):
  87.         random_num = np.random.random()
  88.         if random_num <= blackout_prob:
  89.             noisy_image[black_pixels[0][i]] = -1
  90.            
  91.     return noisy_image
  92.  
  93. def add_random_noise(image, noise_prob=0.1):
  94.  
  95.     black_pixels = np.where(image == 1)
  96.     num_black_pixels = len(black_pixels[0])
  97.    
  98.     noisy_image = blackout_noise(image, blackout_prob=noise_prob)
  99.  
  100.     for i in range(num_black_pixels):
  101.         random_num = np.random.random()
  102.         if random_num <= noise_prob:
  103.             random_idx = np.random.randint(0, len(noisy_image))
  104.             noisy_image[random_idx] = 1
  105.     return noisy_image
  106.  
  107. def generate_noisy_symbols(method='blackout', blackout_prob=0.2, noise_prob=0.2):
  108.     noisy_symbols = {}
  109.     for i in ['П', 'Т', 'К']:
  110.         if method == 'blackout':
  111.             noisy_symbols[str(i)] = blackout_noise(symbols[str(i)], blackout_prob)
  112.         elif method == 'add_random':
  113.             noisy_symbols[str(i)] = add_random_noise(symbols[str(i)], noise_prob)
  114.     return noisy_symbols
  115.  
  116.  
  117.  
  118.  
  119.  
  120. width = 9
  121. height = 9
  122. M = width * height
  123.  
  124. symbols = {str(i): generate_symbol(i, height, width) for i in ['П', 'Т', 'К']}
  125.  
  126. visualize_symbols(symbols)
  127.  
  128. noisy_symbols_blackout = generate_noisy_symbols(method='blackout', blackout_prob=0.2)
  129. noisy_symbols_random = generate_noisy_symbols(method='add_random', noise_prob=0.2)
  130.  
  131. visualize_symbols(noisy_symbols_blackout)
  132. visualize_symbols(noisy_symbols_random)
  133.  
  134. model = Hopfield(M=M, max_iter=100)
  135. model.fit(symbols)
  136.  
  137. for s in ['П', 'Т', 'К']:
  138.     test_symbol = symbols[s]
  139.    
  140.     noisy_test_symbol = add_random_noise(test_symbol, noise_prob=0.2)
  141.     blackout_test_symbol = blackout_noise(test_symbol, blackout_prob=0.2)
  142.    
  143.     recalled_symbol_blackout = model.predict(noisy_test_symbol)
  144.     recalled_symbol_noise = model.predict(blackout_test_symbol)
  145.    
  146.     visualize_comparison(test_symbol, blackout_test_symbol, recalled_symbol_blackout, 'Затирання')
  147.     visualize_comparison(test_symbol, noisy_test_symbol, recalled_symbol_noise, 'Шум')
  148.  
  149.  
  150. n = 9999
  151. df = pd.DataFrame(columns=['blackout_input', 'blackout_prob', 'noise_input', 'noise_prob', 'target', 'symbol', 'blackout_pred', 'blackout_error, %', 'blackout_class_error', 'noise_pred', 'noise_error, %', 'noise_class_error'], index=[i for i in range(n)])
  152. symbol = 'П'
  153.  
  154. df_res = pd.DataFrame(columns=['blackout_error, %', 'noise_error, %', 'blackout_class_error, %', 'noise_class_error, %'])
  155. df_res.index.name = 'prob, %'
  156. num_prob = 250
  157. start_time = dt.datetime.now()
  158.  
  159. for _, prob in enumerate(np.linspace(0, 1, num_prob)):
  160.    
  161.     for i in range(n):
  162.    
  163.         df['symbol'].iloc[i] = symbol
  164.         df['target'].iloc[i] = symbols[symbol]
  165.         df['blackout_prob'].iloc[i] = prob
  166.         df['noise_prob'].iloc[i] = prob
  167.    
  168.         df['blackout_input'].iloc[i] = blackout_noise(symbols[symbol], blackout_prob=prob)
  169.         df['noise_input'].iloc[i] = add_random_noise(symbols[symbol], noise_prob=prob)
  170.  
  171.         df['blackout_pred'].iloc[i] = model.predict(df['blackout_input'].iloc[i])
  172.         df['noise_pred'].iloc[i] = model.predict(df['noise_input'].iloc[i])
  173.  
  174.         df['blackout_error, %'].iloc[i] = np.sum(np.abs(df['target'].iloc[i] - df['blackout_pred'].iloc[i])) / len(df['target'].iloc[i]) * 100
  175.         df['noise_error, %'].iloc[i] = np.sum(np.abs(df['target'].iloc[i] - df['noise_pred'].iloc[i])) / len(df['target'].iloc[i]) * 100
  176.        
  177.         df['blackout_error, %'].iloc[i] = np.sum(np.abs(df['target'].iloc[i] - df['blackout_pred'].iloc[i])) / len(df['target'].iloc[i]) * 100
  178.         df['noise_error, %'].iloc[i] = np.sum(np.abs(df['target'].iloc[i] - df['noise_pred'].iloc[i])) / len(df['target'].iloc[i]) * 100
  179.  
  180.         df['blackout_class_error'].iloc[i] = np.array_equal(df['target'].iloc[i], df['blackout_pred'].iloc[i])
  181.         df['noise_class_error'].iloc[i] = np.array_equal(df['target'].iloc[i], df['noise_pred'].iloc[i])
  182.    
  183.         if symbol == 'П': symbol = 'Т'
  184.         elif symbol == 'Т': symbol = 'К'
  185.         elif symbol == 'К': symbol = 'П'
  186.  
  187.     df_res.loc[prob*100] = [df['blackout_error, %'].mean(), df['noise_error, %'].mean(),
  188.                             (1 - np.sum(df['blackout_class_error']) / len(df['blackout_class_error'])) * 100,
  189.                             (1 - np.sum(df['noise_class_error']) / len(df['noise_class_error'])) * 100]
  190.  
  191.     prcnt = (_+1)/num_prob * 100
  192.     print(f'№{_+1}/{num_prob} - {round(prcnt, 2)}% | total time: {dt.datetime.now() - start_time} | time remaining: {(dt.datetime.now() - start_time) / prcnt * (100 - prcnt)}', end='\r')
  193.     os.system('cls' if os.name == 'nt' else 'clear')
  194.  
  195. # clear_output()
  196.  
  197.  
  198. plt.figure(figsize=(10, 4))
  199. plt.plot(df_res.index, df_res.iloc[:, 0], c='blue', label='Затирання')
  200. plt.plot(df_res.index, df_res.iloc[:, 1], c='red', label='Шум')
  201. plt.legend()
  202. plt.title('Відсоток неправильно відновлених клітинок')
  203. plt.xlabel('Noise level [%]')
  204. plt.ylabel('Percentage of error, [%]')
  205. plt.grid()
  206. plt.show()
  207.  
  208. plt.figure(figsize=(10, 4))
  209. plt.plot(df_res.index, df_res.iloc[:, 2], c='blue', label='Затирання')
  210. plt.plot(df_res.index, df_res.iloc[:, 3], c='red', label='Шум')
  211. plt.legend()
  212. plt.title('Відсоток неправильно відновлених символів')
  213. plt.xlabel('Noise level [%]')
  214. plt.ylabel('Percentage of error, [%]')
  215. plt.grid()
  216. plt.show()
  217.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement